#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates. All Rights Reserved.
"""Video models."""
import torch
import torch.nn as nn
import copy
import slowfast.utils.weight_init_helper as init_helper
from slowfast.models.batchnorm_helper import get_norm
from . import head_helper, resnet_helper, stem_helper
from .build import MODEL_REGISTRY
# Number of blocks for different stages given the model depth.
_MODEL_STAGE_DEPTH = {18: (2, 2, 2, 2), 50: (3, 4, 6, 3), 101: (3, 4, 23, 3)}
# Basis of temporal kernel sizes for each of the stages.
_TEMPORAL_KERNEL_BASIS = {
"c2d": [
[[1]], # conv1 temporal kernel.
[[1]], # res2 temporal kernel.
[[1]], # res3 temporal kernel.
[[1]], # res4 temporal kernel.
[[1]], # res5 temporal kernel.
],
"c2d_nopool": [
[[1]], # conv1 temporal kernel.
[[1]], # res2 temporal kernel.
[[1]], # res3 temporal kernel.
[[1]], # res4 temporal kernel.
[[1]], # res5 temporal kernel.
],
"i3d": [
[[5]], # conv1 temporal kernel.
[[3]], # res2 temporal kernel.
[[3, 1]], # res3 temporal kernel.
[[3, 1]], # res4 temporal kernel.
[[1, 3]], # res5 temporal kernel.
],
"r3d_18": [
[[3]], # conv1 temporal kernel.
[[3]], # res2 temporal kernel.
[[3, 1]], # res3 temporal kernel.
[[3, 1]], # res4 temporal kernel.
[[1, 3]], # res5 temporal kernel.
],
"i3d_nopool": [
[[5]], # conv1 temporal kernel.
[[3]], # res2 temporal kernel.
[[3, 1]], # res3 temporal kernel.
[[3, 1]], # res4 temporal kernel.
[[1, 3]], # res5 temporal kernel.
],
"slow": [
[[1]], # conv1 temporal kernel.
[[1]], # res2 temporal kernel.
[[1]], # res3 temporal kernel.
[[3]], # res4 temporal kernel.
[[3]], # res5 temporal kernel.
],
"slowfast": [
[[1], [5]], # conv1 temporal kernel for slow and fast pathway.
[[1], [3]], # res2 temporal kernel for slow and fast pathway.
[[1], [3]], # res3 temporal kernel for slow and fast pathway.
[[3], [3]], # res4 temporal kernel for slow and fast pathway.
[[3], [3]], # res5 temporal kernel for slow and fast pathway.
],
}
_POOL1 = {
"c2d": [[2, 1, 1]],
"c2d_nopool": [[1, 1, 1]],
"i3d": [[2, 1, 1]],
"r3d_18": [[2, 1, 1]],
"i3d_nopool": [[1, 1, 1]],
"slow": [[1, 1, 1]],
"slowfast": [[1, 1, 1], [1, 1, 1]],
}
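# Illustrative sketch (not part of the original file): the tables above are
# indexed by cfg.MODEL.ARCH when a model is constructed, giving one temporal
# kernel basis per stage and one pool size per pathway. The values below are
# what a lookup for the "slowfast" architecture would return.
#
#   temp_kernel = _TEMPORAL_KERNEL_BASIS["slowfast"]  # 5 stages x 2 pathways
#   pool_size = _POOL1["slowfast"]                     # [[1, 1, 1], [1, 1, 1]]
#   conv1_kernels = temp_kernel[0]                     # [[1], [5]]: slow, fast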
class FuseFastToSlow(nn.Module):
"""
Fuses the information from the Fast pathway to the Slow pathway. Given the
tensors from Slow pathway and Fast pathway, fuse information from Fast to
Slow, then return the fused tensors from Slow and Fast pathway in order.
"""
def __init__(
self,
dim_in,
fusion_conv_channel_ratio,
fusion_kernel,
alpha,
eps=1e-5,
bn_mmt=0.1,
inplace_relu=True,
norm_module=nn.BatchNorm3d,
):
"""
Args:
dim_in (int): the channel dimension of the input.
fusion_conv_channel_ratio (int): channel ratio for the convolution
used to fuse from Fast pathway to Slow pathway.
fusion_kernel (int): kernel size of the convolution used to fuse
from Fast pathway to Slow pathway.
alpha (int): the frame rate ratio between the Fast and Slow pathway.
eps (float): epsilon for batch norm.
bn_mmt (float): momentum for batch norm. Note that BN momentum in
PyTorch = 1 - BN momentum in Caffe2.
inplace_relu (bool): if True, calculate the relu on the original
input without allocating new memory.
norm_module (nn.Module): nn.Module for the normalization layer. The
default is nn.BatchNorm3d.
"""
super(FuseFastToSlow, self).__init__()
self.conv_f2s = nn.Conv3d(
dim_in,
dim_in * fusion_conv_channel_ratio,
kernel_size=[fusion_kernel, 1, 1],
stride=[alpha, 1, 1],
padding=[fusion_kernel // 2, 0, 0],
bias=False,
)
self.bn = norm_module(
num_features=dim_in * fusion_conv_channel_ratio,
eps=eps,
momentum=bn_mmt,
)
self.relu = nn.ReLU(inplace_relu)
def forward(self, x):
x_s = x[0]
x_f = x[1]
fuse = self.conv_f2s(x_f)
fuse = self.bn(fuse)
fuse = self.relu(fuse)
x_s_fuse = torch.cat([x_s, fuse], 1)
return [x_s_fuse, x_f]
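# Hedged usage sketch (values are illustrative assumptions, not from the
# original file): the strided conv_f2s compresses the Fast pathway's time
# axis by alpha so it can be concatenated channel-wise with the Slow pathway.
#
#   fuse = FuseFastToSlow(dim_in=8, fusion_conv_channel_ratio=2,
#                         fusion_kernel=5, alpha=4)
#   slow = torch.randn(1, 64, 4, 56, 56)   # (N, C, T, H, W)
#   fast = torch.randn(1, 8, 16, 56, 56)   # T_fast = alpha * T_slow
#   out_slow, out_fast = fuse([slow, fast])
#   # out_slow: (1, 64 + 8 * 2, 4, 56, 56); out_fast is returned unchanged.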
@MODEL_REGISTRY.register()
class SlowFast(nn.Module):
"""
SlowFast model builder for SlowFast network.
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"SlowFast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
"""
def __init__(self, cfg):
"""
The `__init__` method of any subclass should also contain these
arguments.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
super(SlowFast, self).__init__()
self.norm_module = get_norm(cfg)
self.enable_detection = cfg.DETECTION.ENABLE
self.num_pathways = 2
self._construct_network(cfg)
init_helper.init_weights(
self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
)
def _construct_network(self, cfg):
"""
Builds a SlowFast model. The first pathway is the Slow pathway and the
second pathway is the Fast pathway.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
assert cfg.MODEL.ARCH in _POOL1.keys()
pool_size = _POOL1[cfg.MODEL.ARCH]
assert len({len(pool_size), self.num_pathways}) == 1
assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
(d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]
num_groups = cfg.RESNET.NUM_GROUPS
width_per_group = cfg.RESNET.WIDTH_PER_GROUP
dim_inner = num_groups * width_per_group
out_dim_ratio = (
cfg.SLOWFAST.BETA_INV // cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO
)
temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]
self.s1 = stem_helper.VideoModelStem(
dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
dim_out=[width_per_group, width_per_group // cfg.SLOWFAST.BETA_INV],
kernel=[temp_kernel[0][0] + [7, 7], temp_kernel[0][1] + [7, 7]],
stride=[[1, 2, 2]] * 2,
padding=[
[temp_kernel[0][0][0] // 2, 3, 3],
[temp_kernel[0][1][0] // 2, 3, 3],
],
norm_module=self.norm_module,
)
self.s1_fuse = FuseFastToSlow(
width_per_group // cfg.SLOWFAST.BETA_INV,
cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO,
cfg.SLOWFAST.FUSION_KERNEL_SZ,
cfg.SLOWFAST.ALPHA,
norm_module=self.norm_module,
)
self.s2 = resnet_helper.ResStage(
dim_in=[
width_per_group + width_per_group // out_dim_ratio,
width_per_group // cfg.SLOWFAST.BETA_INV,
],
dim_out=[
width_per_group * 4,
width_per_group * 4 // cfg.SLOWFAST.BETA_INV,
],
dim_inner=[dim_inner, dim_inner // cfg.SLOWFAST.BETA_INV],
temp_kernel_sizes=temp_kernel[1],
stride=cfg.RESNET.SPATIAL_STRIDES[0],
num_blocks=[d2] * 2,
num_groups=[num_groups] * 2,
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
nonlocal_group=cfg.NONLOCAL.GROUP[0],
nonlocal_pool=cfg.NONLOCAL.POOL[0],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
norm_module=self.norm_module,
)
self.s2_fuse = FuseFastToSlow(
width_per_group * 4 // cfg.SLOWFAST.BETA_INV,
cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO,
cfg.SLOWFAST.FUSION_KERNEL_SZ,
cfg.SLOWFAST.ALPHA,
norm_module=self.norm_module,
)
for pathway in range(self.num_pathways):
pool = nn.MaxPool3d(
kernel_size=pool_size[pathway],
stride=pool_size[pathway],
padding=[0, 0, 0],
)
self.add_module("pathway{}_pool".format(pathway), pool)
self.s3 = resnet_helper.ResStage(
dim_in=[
width_per_group * 4 + width_per_group * 4 // out_dim_ratio,
width_per_group * 4 // cfg.SLOWFAST.BETA_INV,
],
dim_out=[
width_per_group * 8,
width_per_group * 8 // cfg.SLOWFAST.BETA_INV,
],
dim_inner=[dim_inner * 2, dim_inner * 2 // cfg.SLOWFAST.BETA_INV],
temp_kernel_sizes=temp_kernel[2],
stride=cfg.RESNET.SPATIAL_STRIDES[1],
num_blocks=[d3] * 2,
num_groups=[num_groups] * 2,
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
nonlocal_group=cfg.NONLOCAL.GROUP[1],
nonlocal_pool=cfg.NONLOCAL.POOL[1],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
norm_module=self.norm_module,
)
self.s3_fuse = FuseFastToSlow(
width_per_group * 8 // cfg.SLOWFAST.BETA_INV,
cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO,
cfg.SLOWFAST.FUSION_KERNEL_SZ,
cfg.SLOWFAST.ALPHA,
norm_module=self.norm_module,
)
self.s4 = resnet_helper.ResStage(
dim_in=[
width_per_group * 8 + width_per_group * 8 // out_dim_ratio,
width_per_group * 8 // cfg.SLOWFAST.BETA_INV,
],
dim_out=[
width_per_group * 16,
width_per_group * 16 // cfg.SLOWFAST.BETA_INV,
],
dim_inner=[dim_inner * 4, dim_inner * 4 // cfg.SLOWFAST.BETA_INV],
temp_kernel_sizes=temp_kernel[3],
stride=cfg.RESNET.SPATIAL_STRIDES[2],
num_blocks=[d4] * 2,
num_groups=[num_groups] * 2,
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
nonlocal_group=cfg.NONLOCAL.GROUP[2],
nonlocal_pool=cfg.NONLOCAL.POOL[2],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
norm_module=self.norm_module,
)
self.s4_fuse = FuseFastToSlow(
width_per_group * 16 // cfg.SLOWFAST.BETA_INV,
cfg.SLOWFAST.FUSION_CONV_CHANNEL_RATIO,
cfg.SLOWFAST.FUSION_KERNEL_SZ,
cfg.SLOWFAST.ALPHA,
norm_module=self.norm_module,
)
self.s5 = resnet_helper.ResStage(
dim_in=[
width_per_group * 16 + width_per_group * 16 // out_dim_ratio,
width_per_group * 16 // cfg.SLOWFAST.BETA_INV,
],
dim_out=[
width_per_group * 32,
width_per_group * 32 // cfg.SLOWFAST.BETA_INV,
],
dim_inner=[dim_inner * 8, dim_inner * 8 // cfg.SLOWFAST.BETA_INV],
temp_kernel_sizes=temp_kernel[4],
stride=cfg.RESNET.SPATIAL_STRIDES[3],
num_blocks=[d5] * 2,
num_groups=[num_groups] * 2,
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
nonlocal_group=cfg.NONLOCAL.GROUP[3],
nonlocal_pool=cfg.NONLOCAL.POOL[3],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
norm_module=self.norm_module,
)
if cfg.DETECTION.ENABLE:
raise NotImplementedError
else:
self.head = head_helper.ResNetBasicHead(
dim_in=[
width_per_group * 32,
width_per_group * 32 // cfg.SLOWFAST.BETA_INV,
],
num_classes=cfg.MODEL.NUM_CLASSES,
pool_size=[None, None]
if cfg.MULTIGRID.SHORT_CYCLE
else [
[
cfg.DATA.NUM_FRAMES
// cfg.SLOWFAST.ALPHA
// pool_size[0][0],
cfg.DATA.CROP_SIZE // 32 // pool_size[0][1],
cfg.DATA.CROP_SIZE // 32 // pool_size[0][2],
],
[
cfg.DATA.NUM_FRAMES // pool_size[1][0],
cfg.DATA.CROP_SIZE // 32 // pool_size[1][1],
cfg.DATA.CROP_SIZE // 32 // pool_size[1][2],
],
], # None for AdaptiveAvgPool3d((1, 1, 1))
dropout_rate=cfg.MODEL.DROPOUT_RATE,
act_func=cfg.MODEL.HEAD_ACT,
)
def forward(self, x, bboxes=None):
x = self.s1(x)
x = self.s1_fuse(x)
x = self.s2(x)
x = self.s2_fuse(x)
for pathway in range(self.num_pathways):
pool = getattr(self, "pathway{}_pool".format(pathway))
x[pathway] = pool(x[pathway])
x = self.s3(x)
x = self.s3_fuse(x)
x = self.s4(x)
x = self.s4_fuse(x)
x = self.s5(x)
if self.enable_detection:
x = self.head(x, bboxes)
else:
x = self.head(x)
return x
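# Hedged usage sketch (shapes and frame counts are assumptions, not taken
# from a specific config): SlowFast expects a list with one clip per pathway,
# the Slow clip sampled ALPHA times more sparsely in time than the Fast clip.
#
#   # model = SlowFast(cfg)
#   # slow = torch.randn(2, 3, 8, 224, 224)    # (N, C, T // alpha, H, W)
#   # fast = torch.randn(2, 3, 32, 224, 224)   # (N, C, T, H, W)
#   # logits = model([slow, fast])             # (N, num_classes)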
@MODEL_REGISTRY.register()
class ResNet(nn.Module):
"""
ResNet model builder. It builds a ResNet like network backbone without
lateral connection (C2D, I3D, Slow).
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"SlowFast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
"Non-local neural networks."
https://arxiv.org/pdf/1711.07971.pdf
"""
def __init__(self, cfg):
"""
The `__init__` method of any subclass should also contain these
arguments.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
super(ResNet, self).__init__()
self.norm_module = get_norm(cfg)
self.enable_detection = cfg.DETECTION.ENABLE
self.num_pathways = 1
self._construct_network(cfg)
init_helper.init_weights(
self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
)
def _construct_network(self, cfg):
"""
Builds a single pathway ResNet model.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
assert cfg.MODEL.ARCH in _POOL1.keys()
pool_size = _POOL1[cfg.MODEL.ARCH]
assert len({len(pool_size), self.num_pathways}) == 1
assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
(d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]
num_groups = cfg.RESNET.NUM_GROUPS
width_per_group = cfg.RESNET.WIDTH_PER_GROUP
dim_inner = num_groups * width_per_group
temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]
self.s1 = stem_helper.VideoModelStem(
dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
dim_out=[width_per_group],
kernel=[temp_kernel[0][0] + [7, 7]],
stride=[[1, 2, 2]],
padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
norm_module=self.norm_module,
)
self.s2 = resnet_helper.ResStage(
dim_in=[width_per_group],
dim_out=[width_per_group * 4],
dim_inner=[dim_inner],
temp_kernel_sizes=temp_kernel[1],
stride=cfg.RESNET.SPATIAL_STRIDES[0],
num_blocks=[d2],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
nonlocal_group=cfg.NONLOCAL.GROUP[0],
nonlocal_pool=cfg.NONLOCAL.POOL[0],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
norm_module=self.norm_module,
)
for pathway in range(self.num_pathways):
pool = nn.MaxPool3d(
kernel_size=pool_size[pathway],
stride=pool_size[pathway],
padding=[0, 0, 0],
)
self.add_module("pathway{}_pool".format(pathway), pool)
self.s3 = resnet_helper.ResStage(
dim_in=[width_per_group * 4],
dim_out=[width_per_group * 8],
dim_inner=[dim_inner * 2],
temp_kernel_sizes=temp_kernel[2],
stride=cfg.RESNET.SPATIAL_STRIDES[1],
num_blocks=[d3],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
nonlocal_group=cfg.NONLOCAL.GROUP[1],
nonlocal_pool=cfg.NONLOCAL.POOL[1],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
norm_module=self.norm_module,
)
self.s4 = resnet_helper.ResStage(
dim_in=[width_per_group * 8],
dim_out=[width_per_group * 16],
dim_inner=[dim_inner * 4],
temp_kernel_sizes=temp_kernel[3],
stride=cfg.RESNET.SPATIAL_STRIDES[2],
num_blocks=[d4],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
nonlocal_group=cfg.NONLOCAL.GROUP[2],
nonlocal_pool=cfg.NONLOCAL.POOL[2],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
norm_module=self.norm_module,
)
self.s5 = resnet_helper.ResStage(
dim_in=[width_per_group * 16],
dim_out=[width_per_group * 32],
dim_inner=[dim_inner * 8],
temp_kernel_sizes=temp_kernel[4],
stride=cfg.RESNET.SPATIAL_STRIDES[3],
num_blocks=[d5],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
nonlocal_group=cfg.NONLOCAL.GROUP[3],
nonlocal_pool=cfg.NONLOCAL.POOL[3],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
norm_module=self.norm_module,
)
if self.enable_detection:
raise NotImplementedError
else:
self.head = head_helper.ResNetBasicHead(
dim_in=[width_per_group * 32],
num_classes=cfg.MODEL.NUM_CLASSES,
pool_size=[None, None]
if cfg.MULTIGRID.SHORT_CYCLE
else [
[
cfg.DATA.NUM_FRAMES // pool_size[0][0],
cfg.DATA.CROP_SIZE // 32 // pool_size[0][1],
cfg.DATA.CROP_SIZE // 32 // pool_size[0][2],
]
], # None for AdaptiveAvgPool3d((1, 1, 1))
dropout_rate=cfg.MODEL.DROPOUT_RATE,
act_func=cfg.MODEL.HEAD_ACT,
)
def forward(self, x, return_feat=False, bboxes=None):
x = self.s1(x)
x = self.s2(x)
for pathway in range(self.num_pathways):
pool = getattr(self, "pathway{}_pool".format(pathway))
x[pathway] = pool(x[pathway])
x = self.s3(x)
x = self.s4(x)
feat = self.s5(x)
if return_feat:
return feat
if self.enable_detection:
x = self.head(feat, bboxes)
else:
x = self.head(feat)
return x
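# Hedged usage sketch (shapes are assumptions): the single pathway builder
# takes a one-element list; return_feat=True skips the head and returns the
# s5 feature maps instead of class scores.
#
#   # model = ResNet(cfg)
#   # clip = torch.randn(2, 3, 8, 224, 224)      # (N, C, T, H, W)
#   # logits = model([clip])                     # (N, num_classes)
#   # feats = model([clip], return_feat=True)    # list with the s5 output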
@MODEL_REGISTRY.register()
class ResNetVar(nn.Module):
"""
ResNet model builder. It builds a ResNet like network backbone without
lateral connection (C2D, I3D, Slow).
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"SlowFast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
"Non-local neural networks."
https://arxiv.org/pdf/1711.07971.pdf
"""
def __init__(self, cfg):
"""
The `__init__` method of any subclass should also contain these
arguments.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
super(ResNetVar, self).__init__()
self.norm_module = get_norm(cfg)
self.enable_detection = cfg.DETECTION.ENABLE
self.num_pathways = 1
self._construct_network(cfg)
init_helper.init_weights(
self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
)
def _construct_network(self, cfg):
"""
Builds a single pathway ResNet model.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
assert cfg.MODEL.ARCH in _POOL1.keys()
pool_size = _POOL1[cfg.MODEL.ARCH]
assert len({len(pool_size), self.num_pathways}) == 1
assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
(d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]
num_groups = cfg.RESNET.NUM_GROUPS
width_per_group = cfg.RESNET.WIDTH_PER_GROUP
dim_inner = num_groups * width_per_group
temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]
self.s1 = stem_helper.VideoModelStem(
dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
dim_out=[width_per_group],
kernel=[temp_kernel[0][0] + [7, 7]],
stride=[[1, 2, 2]],
padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
norm_module=self.norm_module,
)
self.s2 = resnet_helper.ResStage(
dim_in=[width_per_group],
dim_out=[width_per_group * 4],
dim_inner=[dim_inner],
temp_kernel_sizes=temp_kernel[1],
stride=cfg.RESNET.SPATIAL_STRIDES[0],
num_blocks=[d2],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
nonlocal_group=cfg.NONLOCAL.GROUP[0],
nonlocal_pool=cfg.NONLOCAL.POOL[0],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
norm_module=self.norm_module,
)
for pathway in range(self.num_pathways):
pool = nn.MaxPool3d(
kernel_size=pool_size[pathway],
stride=pool_size[pathway],
padding=[0, 0, 0],
)
self.add_module("pathway{}_pool".format(pathway), pool)
self.s3 = resnet_helper.ResStage(
dim_in=[width_per_group * 4],
dim_out=[width_per_group * 8],
dim_inner=[dim_inner * 2],
temp_kernel_sizes=temp_kernel[2],
stride=cfg.RESNET.SPATIAL_STRIDES[1],
num_blocks=[d3],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
nonlocal_group=cfg.NONLOCAL.GROUP[1],
nonlocal_pool=cfg.NONLOCAL.POOL[1],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
norm_module=self.norm_module,
)
self.s4 = resnet_helper.ResStage(
dim_in=[width_per_group * 8],
dim_out=[width_per_group * 16],
dim_inner=[dim_inner * 4],
temp_kernel_sizes=temp_kernel[3],
stride=cfg.RESNET.SPATIAL_STRIDES[2],
num_blocks=[d4],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
nonlocal_group=cfg.NONLOCAL.GROUP[2],
nonlocal_pool=cfg.NONLOCAL.POOL[2],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
norm_module=self.norm_module,
)
self.s5 = resnet_helper.ResStage(
dim_in=[width_per_group * 16],
dim_out=[width_per_group * 32],
dim_inner=[dim_inner * 8],
temp_kernel_sizes=temp_kernel[4],
stride=cfg.RESNET.SPATIAL_STRIDES[3],
num_blocks=[d5],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
nonlocal_group=cfg.NONLOCAL.GROUP[3],
nonlocal_pool=cfg.NONLOCAL.POOL[3],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
norm_module=self.norm_module,
)
if self.enable_detection:
raise NotImplementedError
else:
self.head = head_helper.ResNetBasicHead(
dim_in=[width_per_group * 32],
num_classes=cfg.MODEL.NUM_CLASSES,
pool_size=[None],
dropout_rate=cfg.MODEL.DROPOUT_RATE,
act_func=cfg.MODEL.HEAD_ACT,
)
def forward(self, x, bboxes=None):
x = self.s1(x)
x = self.s2(x)
for pathway in range(self.num_pathways):
pool = getattr(self, "pathway{}_pool".format(pathway))
x[pathway] = pool(x[pathway])
x = self.s3(x)
x = self.s4(x)
x = self.s5(x)
if self.enable_detection:
x = self.head(x, bboxes)
else:
x = self.head(x)
return x
@MODEL_REGISTRY.register()
class ResNetBase(nn.Module):
"""
ResNet model builder. It builds a ResNet like network backbone without
lateral connection (C2D, I3D, Slow).
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"SlowFast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
"Non-local neural networks."
https://arxiv.org/pdf/1711.07971.pdf
"""
def __init__(self, cfg):
"""
The `__init__` method of any subclass should also contain these
arguments.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
super(ResNetBase, self).__init__()
self.norm_module = get_norm(cfg)
self.enable_detection = cfg.DETECTION.ENABLE
self.num_pathways = 1
self._construct_network(cfg)
init_helper.init_weights(
self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
)
def _construct_network(self, cfg):
"""
Builds a single pathway ResNet model.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
assert cfg.MODEL.ARCH in _POOL1.keys()
pool_size = _POOL1[cfg.MODEL.ARCH]
assert len({len(pool_size), self.num_pathways}) == 1
assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
(d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]
num_groups = cfg.RESNET.NUM_GROUPS
width_per_group = cfg.RESNET.WIDTH_PER_GROUP
dim_inner = num_groups * width_per_group
temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]
self.s1 = stem_helper.VideoModelStem(
dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
dim_out=[width_per_group],
kernel=[temp_kernel[0][0] + [7, 7]],
stride=[[1, 2, 2]],
padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
norm_module=self.norm_module,
)
self.s2 = resnet_helper.ResStage(
dim_in=[width_per_group],
dim_out=[width_per_group * 4],
dim_inner=[dim_inner],
temp_kernel_sizes=temp_kernel[1],
stride=cfg.RESNET.SPATIAL_STRIDES[0],
num_blocks=[d2],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
nonlocal_group=cfg.NONLOCAL.GROUP[0],
nonlocal_pool=cfg.NONLOCAL.POOL[0],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
norm_module=self.norm_module,
)
for pathway in range(self.num_pathways):
pool = nn.MaxPool3d(
kernel_size=pool_size[pathway],
stride=pool_size[pathway],
padding=[0, 0, 0],
)
self.add_module("pathway{}_pool".format(pathway), pool)
self.s3 = resnet_helper.ResStage(
dim_in=[width_per_group * 4],
dim_out=[width_per_group * 8],
dim_inner=[dim_inner * 2],
temp_kernel_sizes=temp_kernel[2],
stride=cfg.RESNET.SPATIAL_STRIDES[1],
num_blocks=[d3],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
nonlocal_group=cfg.NONLOCAL.GROUP[1],
nonlocal_pool=cfg.NONLOCAL.POOL[1],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
norm_module=self.norm_module,
)
self.s4 = resnet_helper.ResStage(
dim_in=[width_per_group * 8],
dim_out=[width_per_group * 16],
dim_inner=[dim_inner * 4],
temp_kernel_sizes=temp_kernel[3],
stride=cfg.RESNET.SPATIAL_STRIDES[2],
num_blocks=[d4],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
nonlocal_group=cfg.NONLOCAL.GROUP[2],
nonlocal_pool=cfg.NONLOCAL.POOL[2],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
norm_module=self.norm_module,
)
self.s5 = resnet_helper.ResStage(
dim_in=[width_per_group * 16],
dim_out=[width_per_group * 32],
dim_inner=[dim_inner * 8],
temp_kernel_sizes=temp_kernel[4],
stride=cfg.RESNET.SPATIAL_STRIDES[3],
num_blocks=[d5],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
nonlocal_group=cfg.NONLOCAL.GROUP[3],
nonlocal_pool=cfg.NONLOCAL.POOL[3],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
norm_module=self.norm_module,
)
if self.enable_detection:
raise NotImplementedError
else:
self.head = head_helper.ResNetBasicHead(
dim_in=[width_per_group * 32],
num_classes=cfg.MODEL.NUM_CLASSES,
pool_size=[None, None]
if cfg.MULTIGRID.SHORT_CYCLE
else [
None
], # None for AdaptiveAvgPool3d((1, 1, 1))
dropout_rate=cfg.MODEL.DROPOUT_RATE,
act_func=cfg.MODEL.HEAD_ACT,
)
def forward(self, x, bboxes=None):
x = self.s1(x)
x = self.s2(x)
for pathway in range(self.num_pathways):
pool = getattr(self, "pathway{}_pool".format(pathway))
x[pathway] = pool(x[pathway])
x = self.s3(x)
x = self.s4(x)
x = self.s5(x)
if self.enable_detection:
x = self.head(x, bboxes)
else:
x = self.head(x)
return x
@MODEL_REGISTRY.register()
class ResNetFreeze(nn.Module):
"""
ResNet model builder. It builds a ResNet like network backbone without
lateral connection (C2D, I3D, Slow).
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"SlowFast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
"Non-local neural networks."
https://arxiv.org/pdf/1711.07971.pdf
"""
def __init__(self, cfg):
"""
The `__init__` method of any subclass should also contain these
arguments.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
super(ResNetFreeze, self).__init__()
self.norm_module = get_norm(cfg)
self.enable_detection = cfg.DETECTION.ENABLE
self.num_pathways = 1
self._construct_network(cfg)
init_helper.init_weights(
self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
)
def _construct_network(self, cfg):
"""
Builds a single pathway ResNet model.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
assert cfg.MODEL.ARCH in _POOL1.keys()
pool_size = _POOL1[cfg.MODEL.ARCH]
assert len({len(pool_size), self.num_pathways}) == 1
assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
(d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]
num_groups = cfg.RESNET.NUM_GROUPS
width_per_group = cfg.RESNET.WIDTH_PER_GROUP
dim_inner = num_groups * width_per_group
temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]
self.s1 = stem_helper.VideoModelStem(
dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
dim_out=[width_per_group],
kernel=[temp_kernel[0][0] + [7, 7]],
stride=[[1, 2, 2]],
padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
norm_module=self.norm_module,
)
self.s2 = resnet_helper.ResStage(
dim_in=[width_per_group],
dim_out=[width_per_group * 4],
dim_inner=[dim_inner],
temp_kernel_sizes=temp_kernel[1],
stride=cfg.RESNET.SPATIAL_STRIDES[0],
num_blocks=[d2],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
nonlocal_group=cfg.NONLOCAL.GROUP[0],
nonlocal_pool=cfg.NONLOCAL.POOL[0],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
norm_module=self.norm_module,
)
for pathway in range(self.num_pathways):
pool = nn.MaxPool3d(
kernel_size=pool_size[pathway],
stride=pool_size[pathway],
padding=[0, 0, 0],
)
self.add_module("pathway{}_pool".format(pathway), pool)
self.s3 = resnet_helper.ResStage(
dim_in=[width_per_group * 4],
dim_out=[width_per_group * 8],
dim_inner=[dim_inner * 2],
temp_kernel_sizes=temp_kernel[2],
stride=cfg.RESNET.SPATIAL_STRIDES[1],
num_blocks=[d3],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
nonlocal_group=cfg.NONLOCAL.GROUP[1],
nonlocal_pool=cfg.NONLOCAL.POOL[1],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
norm_module=self.norm_module,
)
self.s4 = resnet_helper.ResStage(
dim_in=[width_per_group * 8],
dim_out=[width_per_group * 16],
dim_inner=[dim_inner * 4],
temp_kernel_sizes=temp_kernel[3],
stride=cfg.RESNET.SPATIAL_STRIDES[2],
num_blocks=[d4],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
nonlocal_group=cfg.NONLOCAL.GROUP[2],
nonlocal_pool=cfg.NONLOCAL.POOL[2],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
norm_module=self.norm_module,
)
self.s5 = resnet_helper.ResStage(
dim_in=[width_per_group * 16],
dim_out=[width_per_group * 32],
dim_inner=[dim_inner * 8],
temp_kernel_sizes=temp_kernel[4],
stride=cfg.RESNET.SPATIAL_STRIDES[3],
num_blocks=[d5],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
nonlocal_group=cfg.NONLOCAL.GROUP[3],
nonlocal_pool=cfg.NONLOCAL.POOL[3],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
norm_module=self.norm_module,
)
if self.enable_detection:
raise NotImplementedError
else:
self.head = head_helper.ResNetBasicHead(
dim_in=[width_per_group * 32],
num_classes=cfg.MODEL.NUM_CLASSES,
pool_size=[None, None]
if cfg.MULTIGRID.SHORT_CYCLE
else [
None
], # None for AdaptiveAvgPool3d((1, 1, 1))
dropout_rate=cfg.MODEL.DROPOUT_RATE,
act_func=cfg.MODEL.HEAD_ACT,
)
def forward(self, x, freeze_backbone=False):
assert isinstance(freeze_backbone, bool)
x = self.s1(x)
x = self.s2(x)
# for pathway in range(self.num_pathways):
# pool = getattr(self, "pathway{}_pool".format(pathway))
# x[pathway] = pool(x[pathway])
x = self.s3(x)
x = self.s4(x)
x = self.s5(x)
if freeze_backbone:
x = [item.detach() for item in x]
x = self.head(x)
return x
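# Hedged usage sketch (the call below is an assumption for illustration):
# with freeze_backbone=True the s5 outputs are detached before the head, so
# gradients only flow through the classifier head.
#
#   # logits = model([clip], freeze_backbone=True)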
import torch.nn.functional as F
from .unet_helper import DecoderBlock, LightDecoderBlock, ResDecoderBlock
@MODEL_REGISTRY.register()
class ResUNet(nn.Module):
"""
ResUNet model builder. It builds a single pathway ResNet-like encoder
without lateral connection (C2D, I3D, Slow) and a U-Net style decoder with
one branch per label.
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"SlowFast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
"Non-local neural networks."
https://arxiv.org/pdf/1711.07971.pdf
"""
def __init__(self, cfg):
"""
The `__init__` method of any subclass should also contain these
arguments.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
super(ResUNet, self).__init__()
self.norm_module = get_norm(cfg)
self.enable_detection = cfg.DETECTION.ENABLE
self.enable_jitter = cfg.JITTER.ENABLE
self.num_pathways = 1
assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
self.image_size = cfg.DATA.TRAIN_CROP_SIZE
self.clip_size = cfg.DATA.NUM_FRAMES
self._construct_network(cfg)
init_helper.init_weights(
self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
)
def _construct_network(self, cfg):
"""
Builds a single pathway ResNet model.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
assert cfg.MODEL.ARCH in _POOL1.keys()
pool_size = _POOL1[cfg.MODEL.ARCH]
self.cfg = cfg
assert len({len(pool_size), self.num_pathways}) == 1
assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
(d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]
num_groups = cfg.RESNET.NUM_GROUPS
width_per_group = cfg.RESNET.WIDTH_PER_GROUP
dim_inner = num_groups * width_per_group
temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]
self.s1 = stem_helper.VideoModelStem(
dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
dim_out=[width_per_group],
kernel=[temp_kernel[0][0] + [7, 7]],
stride=[[1, 2, 2]],
padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
norm_module=self.norm_module,
)
self.s2 = resnet_helper.ResStage(
dim_in=[width_per_group],
dim_out=[width_per_group * 4],
dim_inner=[dim_inner],
temp_kernel_sizes=temp_kernel[1],
stride=cfg.RESNET.SPATIAL_STRIDES[0],
num_blocks=[d2],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
nonlocal_group=cfg.NONLOCAL.GROUP[0],
nonlocal_pool=cfg.NONLOCAL.POOL[0],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
norm_module=self.norm_module,
)
for pathway in range(self.num_pathways):
pool = nn.MaxPool3d(
kernel_size=pool_size[pathway],
stride=pool_size[pathway],
padding=[0, 0, 0],
)
self.add_module("pathway{}_pool".format(pathway), pool)
self.s3 = resnet_helper.ResStage(
dim_in=[width_per_group * 4],
dim_out=[width_per_group * 8],
dim_inner=[dim_inner * 2],
temp_kernel_sizes=temp_kernel[2],
stride=cfg.RESNET.SPATIAL_STRIDES[1],
num_blocks=[d3],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
nonlocal_group=cfg.NONLOCAL.GROUP[1],
nonlocal_pool=cfg.NONLOCAL.POOL[1],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
norm_module=self.norm_module,
)
self.s4 = resnet_helper.ResStage(
dim_in=[width_per_group * 8],
dim_out=[width_per_group * 16],
dim_inner=[dim_inner * 4],
temp_kernel_sizes=temp_kernel[3],
stride=cfg.RESNET.SPATIAL_STRIDES[2],
num_blocks=[d4],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
nonlocal_group=cfg.NONLOCAL.GROUP[2],
nonlocal_pool=cfg.NONLOCAL.POOL[2],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
norm_module=self.norm_module,
)
# self.s5 = resnet_helper.ResStage(
# dim_in=[width_per_group * 16],
# dim_out=[width_per_group * 32],
# dim_inner=[dim_inner * 8],
# temp_kernel_sizes=temp_kernel[4],
# stride=cfg.RESNET.SPATIAL_STRIDES[3],
# num_blocks=[d5],
# num_groups=[num_groups],
# num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
# nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
# nonlocal_group=cfg.NONLOCAL.GROUP[3],
# nonlocal_pool=cfg.NONLOCAL.POOL[3],
# instantiation=cfg.NONLOCAL.INSTANTIATION,
# trans_func_name=cfg.RESNET.TRANS_FUNC,
# stride_1x1=cfg.RESNET.STRIDE_1X1,
# inplace_relu=cfg.RESNET.INPLACE_RELU,
# dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
# norm_module=self.norm_module,
# )
self.labels = ["rotate", "light"]
self.dual_define("t4", self.labels, DecoderBlock(width_per_group * 16, width_per_group * 8, width_per_group * 8))
self.dual_define("t3", self.labels, DecoderBlock(width_per_group * 8, width_per_group * 4, 256))
self.dual_define("conv1x1", self.labels, nn.Sequential(
nn.Conv3d(width_per_group * 4 + width_per_group, 1, kernel_size=(1, 1, 1), stride=1, padding=0), nn.Sigmoid()
))
self.linear = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid())
def forward_plus(self, x, y, net):
return [net(x)[0] + y[0]]
def dual_define(self, name, labels, net):
for label in labels:
self.add_module(f"{name}_{label}", copy.deepcopy(net))
def upsample(self, x, dims=["space"]):
ori_size = x[0].shape[2:5]
t, h, w = ori_size
if "space" in dims:
h = 2 * h
w = 2 * w
if "time" in dims:
t = 2 * t
size = (t, h, w)
return [F.interpolate(x[0], size)]
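# Note (added comment): F.interpolate defaults to nearest-neighbor
# interpolation here; "space" doubles H and W, "time" doubles T.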
def concat(self, x, y):
return [torch.cat([x[0], y[0]], 1)]
# @torchsnooper.snoop()
def forward(self, x, bboxes=None):
x1 = self.s1(x) # 1,64,8,56,56
x2 = self.s2(x1) # 1,256,8,56,56
x3 = self.s3(x2) # 1,512,8,28, 28
x = self.s4(x3) # 1,1024,8,14,14
x = self.upsample(x) # 1,1024, 8, 28, 28
x = self.concat(x3, x)  # 1,1024+512, 8, 28, 28
x = [self.forward_branch(x, x1, x2, label) for label in self.labels]
x = torch.cat(x, 1)
out = x.mean([3, 4]).view(-1, 1) * 100
out = self.linear(out)
out = out.view(x.size(0), -1)
return x, out
def forward_branch(self, x, x1, x2, label):
t4 = getattr(self, f"t4_{label}")
x = t4(x[0])  # 1,512, 8, 28, 28
x = self.upsample([x])  # 1,512, 8, 56, 56
x = self.concat(x2, x)  # 1,256+512, 8, 56, 56
t3 = getattr(self, f"t3_{label}")
x = t3(x[0])  # 1,256, 8, 56, 56
x = self.concat(x1, [x])  # 1,320, 8, 56, 56
conv1x1 = getattr(self, f"conv1x1_{label}")
x = conv1x1(x[0])  # 1,1,8,56,56
return x
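# Hedged usage sketch (shapes follow the inline comments above and are
# assumptions): ResUNet returns per-label spatio-temporal maps plus a
# clip-level score built from their spatial mean.
#
#   # model = ResUNet(cfg)
#   # clip = torch.randn(1, 3, 8, 112, 112)
#   # maps, score = model([clip])
#   # maps: (1, len(model.labels), 8, 56, 56); score: (1, len(model.labels) * 8)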
@MODEL_REGISTRY.register()
class ResUNetLight(nn.Module):
"""
ResUNetLight model builder. It builds a single pathway ResNet-like encoder
without lateral connection (C2D, I3D, Slow) and a lightweight U-Net style
decoder with one branch per label.
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"SlowFast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
"Non-local neural networks."
https://arxiv.org/pdf/1711.07971.pdf
"""
def __init__(self, cfg):
"""
The `__init__` method of any subclass should also contain these
arguments.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
super(ResUNetLight, self).__init__()
self.norm_module = get_norm(cfg)
self.enable_detection = cfg.DETECTION.ENABLE
self.enable_jitter = cfg.JITTER.ENABLE
self.num_pathways = 1
assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
self.image_size = cfg.DATA.TRAIN_CROP_SIZE
self.clip_size = cfg.DATA.NUM_FRAMES
self._construct_network(cfg)
init_helper.init_weights(
self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
)
def _construct_network(self, cfg):
"""
Builds a single pathway ResNet model.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
assert cfg.MODEL.ARCH in _POOL1.keys()
pool_size = _POOL1[cfg.MODEL.ARCH]
self.cfg = cfg
assert len({len(pool_size), self.num_pathways}) == 1
assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
(d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]
num_groups = cfg.RESNET.NUM_GROUPS
width_per_group = cfg.RESNET.WIDTH_PER_GROUP
dim_inner = num_groups * width_per_group
temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]
self.s1 = stem_helper.VideoModelStem(
dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
dim_out=[width_per_group],
kernel=[temp_kernel[0][0] + [7, 7]],
stride=[[1, 2, 2]],
padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
norm_module=self.norm_module,
)
self.s2 = resnet_helper.ResStage(
dim_in=[width_per_group],
dim_out=[width_per_group * 4],
dim_inner=[dim_inner],
temp_kernel_sizes=temp_kernel[1],
stride=cfg.RESNET.SPATIAL_STRIDES[0],
num_blocks=[d2],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
nonlocal_group=cfg.NONLOCAL.GROUP[0],
nonlocal_pool=cfg.NONLOCAL.POOL[0],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
norm_module=self.norm_module,
)
for pathway in range(self.num_pathways):
pool = nn.MaxPool3d(
kernel_size=pool_size[pathway],
stride=pool_size[pathway],
padding=[0, 0, 0],
)
self.add_module("pathway{}_pool".format(pathway), pool)
self.s3 = resnet_helper.ResStage(
dim_in=[width_per_group * 4],
dim_out=[width_per_group * 8],
dim_inner=[dim_inner * 2],
temp_kernel_sizes=temp_kernel[2],
stride=cfg.RESNET.SPATIAL_STRIDES[1],
num_blocks=[d3],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
nonlocal_group=cfg.NONLOCAL.GROUP[1],
nonlocal_pool=cfg.NONLOCAL.POOL[1],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
norm_module=self.norm_module,
)
self.s4 = resnet_helper.ResStage(
dim_in=[width_per_group * 8],
dim_out=[width_per_group * 16],
dim_inner=[dim_inner * 4],
temp_kernel_sizes=temp_kernel[3],
stride=cfg.RESNET.SPATIAL_STRIDES[2],
num_blocks=[d4],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
nonlocal_group=cfg.NONLOCAL.GROUP[2],
nonlocal_pool=cfg.NONLOCAL.POOL[2],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
norm_module=self.norm_module,
)
# self.s5 = resnet_helper.ResStage(
# dim_in=[width_per_group * 16],
# dim_out=[width_per_group * 32],
# dim_inner=[dim_inner * 8],
# temp_kernel_sizes=temp_kernel[4],
# stride=cfg.RESNET.SPATIAL_STRIDES[3],
# num_blocks=[d5],
# num_groups=[num_groups],
# num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
# nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
# nonlocal_group=cfg.NONLOCAL.GROUP[3],
# nonlocal_pool=cfg.NONLOCAL.POOL[3],
# instantiation=cfg.NONLOCAL.INSTANTIATION,
# trans_func_name=cfg.RESNET.TRANS_FUNC,
# stride_1x1=cfg.RESNET.STRIDE_1X1,
# inplace_relu=cfg.RESNET.INPLACE_RELU,
# dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
# norm_module=self.norm_module,
# )
self.labels = ["rotate", "light"]
self.dual_define("t4", self.labels, LightDecoderBlock(width_per_group * 16, width_per_group * 8, width_per_group * 4))
self.dual_define("t3", self.labels, LightDecoderBlock(width_per_group * 4, width_per_group * 4, 128))
self.dual_define("conv1x1", self.labels, nn.Sequential(
nn.Conv3d(128 + width_per_group, 1, kernel_size=(1, 1, 1), stride=1, padding=0), nn.Sigmoid()
))
self.linear = nn.Sequential(nn.Linear(1, 1), nn.Sigmoid())
def forward_plus(self, x, y, net):
return [net(x)[0] + y[0]]
def dual_define(self, name, labels, net):
for label in labels:
self.add_module(f"{name}_{label}", copy.deepcopy(net))
def upsample(self, x, dims=["space"]):
ori_size = x[0].shape[2:5]
t, h, w = ori_size
if "space" in dims:
h = 2 * h
w = 2 * w
if "time" in dims:
t = 2 * t
size = (t, h, w)
return [F.interpolate(x[0], size)]
def concat(self, x, y):
return [torch.cat([x[0], y[0]], 1)]
def get_detach_var(self, x):
return [t.detach() for t in x]
# @torchsnooper.snoop()
def forward(self, x, freeze_backbone=False):
x1 = self.s1(x) # 1,64,8,56,56
x2 = self.s2(x1) # 1,256,8,56,56
x3 = self.s3(x2) # 1,512,8,28, 28
x = self.s4(x3) # 1,1024,8,14,14
assert isinstance(freeze_backbone, bool)
if freeze_backbone:
x = self.get_detach_var(x)
x1 = self.get_detach_var(x1)
x2 = self.get_detach_var(x2)
x3 = self.get_detach_var(x3)
x = self.upsample(x) # 1,1024, 8, 28, 28
x = self.concat(x3, x)  # 1,1024+512, 8, 28, 28
x = [self.forward_branch(x, x1, x2, label) for label in self.labels]
x = torch.cat(x, 1)
out = x.mean([3, 4]).view(-1, 1) * 100  # 1,2,8,56,56
out = self.linear(out)
out = out.view(x.size(0), -1)
return x, out
def forward_branch(self, x, x1, x2, label):
t4 = getattr(self, f"t4_{label}")
x = t4(x[0])  # 1,256, 8, 28, 28
x = self.upsample([x])  # 1,256, 8, 56, 56
x = self.concat(x2, x)  # 1,256+256, 8, 56, 56
t3 = getattr(self, f"t3_{label}")
x = t3(x[0])  # 1,128, 8, 56, 56
x = self.concat(x1, [x])  # 1,192, 8, 56, 56
conv1x1 = getattr(self, f"conv1x1_{label}")
x = conv1x1(x[0])  # 1,1,8,56,56
return x
@MODEL_REGISTRY.register()
class ResUNetLightFix(nn.Module):
"""
ResUNetLightFix model builder. It builds a single pathway ResNet-like
encoder without lateral connection (C2D, I3D, Slow) and a lightweight U-Net
style decoder with one branch per label.
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"SlowFast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
"Non-local neural networks."
https://arxiv.org/pdf/1711.07971.pdf
"""
def __init__(self, cfg):
"""
The `__init__` method of any subclass should also contain these
arguments.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
super(ResUNetLightFix, self).__init__()
self.norm_module = get_norm(cfg)
self.enable_detection = cfg.DETECTION.ENABLE
self.enable_jitter = cfg.JITTER.ENABLE
self.num_pathways = 1
assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
self.image_size = cfg.DATA.TRAIN_CROP_SIZE
self.clip_size = cfg.DATA.NUM_FRAMES
self._construct_network(cfg)
init_helper.init_weights(
self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
)
def _construct_network(self, cfg):
"""
Builds a single pathway ResNet model.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
assert cfg.MODEL.ARCH in _POOL1.keys()
pool_size = _POOL1[cfg.MODEL.ARCH]
self.cfg = cfg
assert len({len(pool_size), self.num_pathways}) == 1
assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
(d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]
num_groups = cfg.RESNET.NUM_GROUPS
width_per_group = cfg.RESNET.WIDTH_PER_GROUP
dim_inner = num_groups * width_per_group
temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]
self.s1 = stem_helper.VideoModelStem(
dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
dim_out=[width_per_group],
kernel=[temp_kernel[0][0] + [7, 7]],
stride=[[1, 2, 2]],
padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
norm_module=self.norm_module,
)
self.s2 = resnet_helper.ResStage(
dim_in=[width_per_group],
dim_out=[width_per_group * 4],
dim_inner=[dim_inner],
temp_kernel_sizes=temp_kernel[1],
stride=cfg.RESNET.SPATIAL_STRIDES[0],
num_blocks=[d2],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
nonlocal_group=cfg.NONLOCAL.GROUP[0],
nonlocal_pool=cfg.NONLOCAL.POOL[0],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
norm_module=self.norm_module,
)
for pathway in range(self.num_pathways):
pool = nn.MaxPool3d(
kernel_size=pool_size[pathway],
stride=pool_size[pathway],
padding=[0, 0, 0],
)
self.add_module("pathway{}_pool".format(pathway), pool)
self.s3 = resnet_helper.ResStage(
dim_in=[width_per_group * 4],
dim_out=[width_per_group * 8],
dim_inner=[dim_inner * 2],
temp_kernel_sizes=temp_kernel[2],
stride=cfg.RESNET.SPATIAL_STRIDES[1],
num_blocks=[d3],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
nonlocal_group=cfg.NONLOCAL.GROUP[1],
nonlocal_pool=cfg.NONLOCAL.POOL[1],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
norm_module=self.norm_module,
)
self.s4 = resnet_helper.ResStage(
dim_in=[width_per_group * 8],
dim_out=[width_per_group * 16],
dim_inner=[dim_inner * 4],
temp_kernel_sizes=temp_kernel[3],
stride=cfg.RESNET.SPATIAL_STRIDES[2],
num_blocks=[d4],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
nonlocal_group=cfg.NONLOCAL.GROUP[2],
nonlocal_pool=cfg.NONLOCAL.POOL[2],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
norm_module=self.norm_module,
)
# self.s5 = resnet_helper.ResStage(
# dim_in=[width_per_group * 16],
# dim_out=[width_per_group * 32],
# dim_inner=[dim_inner * 8],
# temp_kernel_sizes=temp_kernel[4],
# stride=cfg.RESNET.SPATIAL_STRIDES[3],
# num_blocks=[d5],
# num_groups=[num_groups],
# num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
# nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
# nonlocal_group=cfg.NONLOCAL.GROUP[3],
# nonlocal_pool=cfg.NONLOCAL.POOL[3],
# instantiation=cfg.NONLOCAL.INSTANTIATION,
# trans_func_name=cfg.RESNET.TRANS_FUNC,
# stride_1x1=cfg.RESNET.STRIDE_1X1,
# inplace_relu=cfg.RESNET.INPLACE_RELU,
# dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
# norm_module=self.norm_module,
# )
self.labels = ["rotate", "light", "skip"]
self.dual_define("t4", self.labels, LightDecoderBlock(width_per_group * 16, width_per_group * 8, width_per_group * 4))
self.dual_define("t3", self.labels, LightDecoderBlock(width_per_group * 4, width_per_group * 4, 128))
self.dual_define("conv1x1", self.labels, nn.Sequential(
nn.Conv3d(128 + width_per_group, 64, kernel_size=(1, 1, 1), stride=1, padding=0),
nn.BatchNorm3d(64),
nn.ReLU(),
nn.Conv3d(64, 1, kernel_size=(1, 1, 1), stride=1, padding=0),
))
self.linear = nn.Sequential(nn.Linear(1, 1))
def forward_plus(self, x, y, net):
return [net(x)[0] + y[0]]
def dual_define(self, name, labels, net):
for label in labels:
self.add_module(f"{name}_{label}", copy.deepcopy(net))
def upsample(self, x, dims=["space"]):
ori_size = x[0].shape[2:5]
t, h, w = ori_size
if "space" in dims:
h = 2 * h
w = 2 * w
if "time" in dims:
t = 2 * t
size = (t, h, w)
return [F.interpolate(x[0], size)]
def concat(self, x, y):
return [torch.cat([x[0], y[0]], 1)]
def get_detach_var(self, x):
return [t.detach() for t in x]
# @torchsnooper.snoop()
def forward(self, x, freeze_backbone=False):
x1 = self.s1(x) # 1,64,8,56,56
x2 = self.s2(x1) # 1,256,8,56,56
x3 = self.s3(x2) # 1,512,8,28, 28
x = self.s4(x3) # 1,1024,8,14,14
assert isinstance(freeze_backbone, bool)
if freeze_backbone:
x = self.get_detach_var(x)
x1 = self.get_detach_var(x1)
x2 = self.get_detach_var(x2)
x3 = self.get_detach_var(x3)
x = self.upsample(x) # 1,1024, 8, 28, 28
x = self.concat(x3, x)  # 1,1024+512, 8, 28, 28
x = [self.forward_branch(x, x1, x2, label) for label in self.labels]
x = torch.cat(x, 1)
x = torch.sigmoid(x)
out = x.mean([3, 4]).view(-1, 1) * 100  # x: 1,3,8,56,56
out = self.linear(out)
out = out.view(x.size(0), -1)
out = torch.sigmoid(out)
return x, out
def forward_branch(self, x, x1, x2, label):
t4 = getattr(self, f"t4_{label}")
x = t4(x[0])  # 1,256, 8, 28, 28
x = self.upsample([x])  # 1,256, 8, 56, 56
x = self.concat(x2, x)  # 1,256+256, 8, 56, 56
t3 = getattr(self, f"t3_{label}")
x = t3(x[0])  # 1,128, 8, 56, 56
x = self.concat(x1, [x])  # 1,192, 8, 56, 56
conv1x1 = getattr(self, f"conv1x1_{label}")
x = conv1x1(x[0])  # 1,1,8,56,56
return x
@MODEL_REGISTRY.register()
class ResUNetContinus(nn.Module):
"""
ResUNetContinus model builder. It builds a single pathway ResNet-like
encoder without lateral connection (C2D, I3D, Slow) and a lightweight U-Net
style decoder with a single shared branch.
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"SlowFast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
"Non-local neural networks."
https://arxiv.org/pdf/1711.07971.pdf
"""
def __init__(self, cfg):
"""
The `__init__` method of any subclass should also contain these
arguments.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
super(ResUNetContinus, self).__init__()
self.norm_module = get_norm(cfg)
self.enable_detection = cfg.DETECTION.ENABLE
self.enable_jitter = cfg.JITTER.ENABLE
self.num_pathways = 1
assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
self.image_size = cfg.DATA.TRAIN_CROP_SIZE
self.clip_size = cfg.DATA.NUM_FRAMES
self._construct_network(cfg)
init_helper.init_weights(
self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
)
def _construct_network(self, cfg):
"""
Builds a single pathway ResNet model.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
assert cfg.MODEL.ARCH in _POOL1.keys()
pool_size = _POOL1[cfg.MODEL.ARCH]
self.cfg = cfg
assert len({len(pool_size), self.num_pathways}) == 1
assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
(d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]
num_groups = cfg.RESNET.NUM_GROUPS
width_per_group = cfg.RESNET.WIDTH_PER_GROUP
dim_inner = num_groups * width_per_group
temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]
self.s1 = stem_helper.VideoModelStem(
dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
dim_out=[width_per_group],
kernel=[temp_kernel[0][0] + [7, 7]],
stride=[[1, 2, 2]],
padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
norm_module=self.norm_module,
)
self.s2 = resnet_helper.ResStage(
dim_in=[width_per_group],
dim_out=[width_per_group * 4],
dim_inner=[dim_inner],
temp_kernel_sizes=temp_kernel[1],
stride=cfg.RESNET.SPATIAL_STRIDES[0],
num_blocks=[d2],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
nonlocal_group=cfg.NONLOCAL.GROUP[0],
nonlocal_pool=cfg.NONLOCAL.POOL[0],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
norm_module=self.norm_module,
)
for pathway in range(self.num_pathways):
pool = nn.MaxPool3d(
kernel_size=pool_size[pathway],
stride=pool_size[pathway],
padding=[0, 0, 0],
)
self.add_module("pathway{}_pool".format(pathway), pool)
self.s3 = resnet_helper.ResStage(
dim_in=[width_per_group * 4],
dim_out=[width_per_group * 8],
dim_inner=[dim_inner * 2],
temp_kernel_sizes=temp_kernel[2],
stride=cfg.RESNET.SPATIAL_STRIDES[1],
num_blocks=[d3],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
nonlocal_group=cfg.NONLOCAL.GROUP[1],
nonlocal_pool=cfg.NONLOCAL.POOL[1],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
norm_module=self.norm_module,
)
self.s4 = resnet_helper.ResStage(
dim_in=[width_per_group * 8],
dim_out=[width_per_group * 16],
dim_inner=[dim_inner * 4],
temp_kernel_sizes=temp_kernel[3],
stride=cfg.RESNET.SPATIAL_STRIDES[2],
num_blocks=[d4],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
nonlocal_group=cfg.NONLOCAL.GROUP[2],
nonlocal_pool=cfg.NONLOCAL.POOL[2],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
norm_module=self.norm_module,
)
# self.s5 = resnet_helper.ResStage(
# dim_in=[width_per_group * 16],
# dim_out=[width_per_group * 32],
# dim_inner=[dim_inner * 8],
# temp_kernel_sizes=temp_kernel[4],
# stride=cfg.RESNET.SPATIAL_STRIDES[3],
# num_blocks=[d5],
# num_groups=[num_groups],
# num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
# nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
# nonlocal_group=cfg.NONLOCAL.GROUP[3],
# nonlocal_pool=cfg.NONLOCAL.POOL[3],
# instantiation=cfg.NONLOCAL.INSTANTIATION,
# trans_func_name=cfg.RESNET.TRANS_FUNC,
# stride_1x1=cfg.RESNET.STRIDE_1X1,
# inplace_relu=cfg.RESNET.INPLACE_RELU,
# dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
# norm_module=self.norm_module,
# )
self.labels = ["all"]
self.dual_define("t4", self.labels, LightDecoderBlock(width_per_group * 16, width_per_group * 8, width_per_group * 4))
self.dual_define("t3", self.labels, LightDecoderBlock(width_per_group * 4, width_per_group * 4, 128))
self.dual_define("conv1x1", self.labels, nn.Sequential(
nn.Conv3d(128 + width_per_group, 64, kernel_size=(1, 1, 1), stride=1, padding=0),
nn.BatchNorm3d(64),
nn.ReLU(),
nn.Conv3d(64, 1, kernel_size=(1, 1, 1), stride=1, padding=0),
))
self.linear = nn.Sequential(nn.Linear(1, 1))
def forward_plus(self, x, y, net):
return [net(x)[0] + y[0]]
def dual_define(self, name, labels, net):
for label in labels:
self.add_module(f"{name}_{label}", copy.deepcopy(net))
def upsample(self, x, dims=["space"]):
ori_size = x[0].shape[2:5]
t, h, w = ori_size
if "space" in dims:
h = 2 * h
w = 2 * w
if "time" in dims:
t = 2 * t
size = (t, h, w)
return [F.interpolate(x[0], size)]
def concat(self, x, y):
return [torch.cat([x[0], y[0]], 1)]
def get_detach_var(self, x):
return [t.detach() for t in x]
# @torchsnooper.snoop()
def forward(self, x, freeze_backbone=False):
x1 = self.s1(x)  # 1,64,8,56,56
x2 = self.s2(x1)  # 1,256,8,56,56
x3 = self.s3(x2)  # 1,512,8,28,28
x = self.s4(x3)  # 1,1024,8,14,14
assert isinstance(freeze_backbone, bool)
if freeze_backbone:
x = self.get_detach_var(x)
x1 = self.get_detach_var(x1)
x2 = self.get_detach_var(x2)
x3 = self.get_detach_var(x3)
x = self.upsample(x)  # 1,1024,8,28,28
x = self.concat(x3, x)  # 1,1024+512,8,28,28
x = [self.forward_branch(x, x1, x2, label) for label in self.labels]
x = torch.cat(x, 1)
x = torch.sigmoid(x)  # dense per-label maps: N,num_labels,8,56,56
out = x.mean([3, 4]).view(-1, 1) * 100  # spatial mean, flattened to (N * num_labels * T, 1)
out = self.linear(out)
out = out.view(x.size(0), -1)
out = torch.sigmoid(out)
return x, out
def forward_branch(self, x, x1, x2, label):
t4 = getattr(self, f"t4_{label}")
x = t4(x[0])  # 1,256,8,28,28
x = self.upsample([x])  # 1,256,8,56,56
x = self.concat(x2, x)  # 1,256+256,8,56,56
t3 = getattr(self, f"t3_{label}")
x = t3(x[0])  # 1,128,8,56,56
x = self.concat(x1, [x])  # 1,128+64,8,56,56
conv1x1 = getattr(self, f"conv1x1_{label}")
x = conv1x1(x[0])  # single-channel map per branch: 1,1,8,56,56
return x
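# Usage sketch (an assumption, not part of this file): classes decorated
# with @MODEL_REGISTRY.register() are normally looked up by name from the
# config and instantiated with the cfg node, roughly along the lines of
#
#     model = MODEL_REGISTRY.get(cfg.MODEL.MODEL_NAME)(cfg)
#
# so setting cfg.MODEL.MODEL_NAME to "ResUNetCommon" selects the class
# registered below.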
@MODEL_REGISTRY.register()
class ResUNetCommon(nn.Module):
"""
ResNet model builder. It builds a ResNet like network backbone without
lateral connection (C2D, I3D, Slow).
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"SlowFast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
"Non-local neural networks."
https://arxiv.org/pdf/1711.07971.pdf
"""
def __init__(self, cfg):
"""
The `__init__` method of any subclass should also contain these
arguments.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
super(ResUNetCommon, self).__init__()
self.norm_module = get_norm(cfg)
self.enable_detection = cfg.DETECTION.ENABLE
self.enable_jitter = cfg.JITTER.ENABLE
self.num_pathways = 1
assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
self.image_size = cfg.DATA.TRAIN_CROP_SIZE
self.clip_size = cfg.DATA.NUM_FRAMES
self._construct_network(cfg)
init_helper.init_weights(
self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
)
def _construct_network(self, cfg):
"""
Builds a single pathway ResNet model.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
assert cfg.MODEL.ARCH in _POOL1.keys()
pool_size = _POOL1[cfg.MODEL.ARCH]
self.cfg = cfg
assert len({len(pool_size), self.num_pathways}) == 1
assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
(d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]
num_groups = cfg.RESNET.NUM_GROUPS
width_per_group = cfg.RESNET.WIDTH_PER_GROUP
dim_inner = num_groups * width_per_group
temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]
self.s1 = stem_helper.VideoModelStem(
dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
dim_out=[width_per_group],
kernel=[temp_kernel[0][0] + [7, 7]],
stride=[[1, 2, 2]],
padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
norm_module=self.norm_module,
)
self.s2 = resnet_helper.ResStage(
dim_in=[width_per_group],
dim_out=[width_per_group * 4],
dim_inner=[dim_inner],
temp_kernel_sizes=temp_kernel[1],
stride=cfg.RESNET.SPATIAL_STRIDES[0],
num_blocks=[d2],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
nonlocal_group=cfg.NONLOCAL.GROUP[0],
nonlocal_pool=cfg.NONLOCAL.POOL[0],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
norm_module=self.norm_module,
)
for pathway in range(self.num_pathways):
pool = nn.MaxPool3d(
kernel_size=pool_size[pathway],
stride=pool_size[pathway],
padding=[0, 0, 0],
)
self.add_module("pathway{}_pool".format(pathway), pool)
self.s3 = resnet_helper.ResStage(
dim_in=[width_per_group * 4],
dim_out=[width_per_group * 8],
dim_inner=[dim_inner * 2],
temp_kernel_sizes=temp_kernel[2],
stride=cfg.RESNET.SPATIAL_STRIDES[1],
num_blocks=[d3],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
nonlocal_group=cfg.NONLOCAL.GROUP[1],
nonlocal_pool=cfg.NONLOCAL.POOL[1],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
norm_module=self.norm_module,
)
self.s4 = resnet_helper.ResStage(
dim_in=[width_per_group * 8],
dim_out=[width_per_group * 16],
dim_inner=[dim_inner * 4],
temp_kernel_sizes=temp_kernel[3],
stride=cfg.RESNET.SPATIAL_STRIDES[2],
num_blocks=[d4],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
nonlocal_group=cfg.NONLOCAL.GROUP[2],
nonlocal_pool=cfg.NONLOCAL.POOL[2],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
norm_module=self.norm_module,
)
# self.s5 = resnet_helper.ResStage(
# dim_in=[width_per_group * 16],
# dim_out=[width_per_group * 32],
# dim_inner=[dim_inner * 8],
# temp_kernel_sizes=temp_kernel[4],
# stride=cfg.RESNET.SPATIAL_STRIDES[3],
# num_blocks=[d5],
# num_groups=[num_groups],
# num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
# nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
# nonlocal_group=cfg.NONLOCAL.GROUP[3],
# nonlocal_pool=cfg.NONLOCAL.POOL[3],
# instantiation=cfg.NONLOCAL.INSTANTIATION,
# trans_func_name=cfg.RESNET.TRANS_FUNC,
# stride_1x1=cfg.RESNET.STRIDE_1X1,
# inplace_relu=cfg.RESNET.INPLACE_RELU,
# dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
# norm_module=self.norm_module,
# )
self.labels = cfg.RESNET.LABELS
self.dual_define("t4", self.labels, LightDecoderBlock(width_per_group * 16, width_per_group * 8, width_per_group * 4))
self.dual_define("t3", self.labels, LightDecoderBlock(width_per_group * 4, width_per_group * 4, 128))
self.dual_define("conv1x1", self.labels, nn.Sequential(
nn.Conv3d(128 + width_per_group, 64, kernel_size=(1, 1, 1), stride=1, padding=0),
nn.BatchNorm3d(64),
nn.ReLU(),
nn.Conv3d(64, 1, kernel_size=(1, 1, 1), stride=1, padding=0),
))
self.linear = nn.Linear(1, 2)
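# One decoder head is registered per label in cfg.RESNET.LABELS; the shared
# nn.Linear(1, 2) maps the spatially averaged map activation to two logits
# per (label, frame). Unlike ResUNetCommon2 below, forward() returns these
# logits without a sigmoid.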
def forward_plus(self, x, y, net):
return [net(x)[0] + y[0]]
def dual_define(self, name, labels, net):
for label in labels:
self.add_module(f"{name}_{label}", copy.deepcopy(net))
def upsample(self, x, dims=("space",)):
ori_size = x[0].shape[2:5]
t, h, w = ori_size
if "space" in dims:
h = 2 * h
w = 2 * w
if "time" in dims:
t = 2 * t
size = (t, h, w)
return [F.interpolate(x[0], size)]
def concat(self, x, y):
return [torch.cat([x[0], y[0]], 1)]
def get_detach_var(self, x):
return [t.detach() for t in x]
# @torchsnooper.snoop()
def forward(self, x, freeze_backbone=False):
x = self.get_detach_var(x)
x1 = self.s1(x)  # 1,64,8,56,56
x2 = self.s2(x1)  # 1,256,8,56,56
x3 = self.s3(x2)  # 1,512,8,28,28
feat = self.s4(x3)  # 1,1024,8,14,14
assert isinstance(freeze_backbone, bool)
if freeze_backbone:
feat = self.get_detach_var(feat)
x1 = self.get_detach_var(x1)
x2 = self.get_detach_var(x2)
x3 = self.get_detach_var(x3)
feat = self.upsample(feat)  # 1,1024,8,28,28
feat = self.concat(x3, feat)  # 1,1024+512,8,28,28
reg_out = [self.forward_branch(feat, x1, x2, label) for label in self.labels]
reg_out = torch.cat(reg_out, 1)
reg_out = torch.sigmoid(reg_out)  # dense per-label maps: N,num_labels,8,56,56
class_out = reg_out.mean([3, 4]).view(-1, 1) * 100  # spatial mean, flattened to (N * num_labels * T, 1)
class_out = self.linear(class_out)
class_out = class_out.view(reg_out.size(0), len(self.labels), -1)
# Returned as raw logits; no sigmoid is applied here (cf. ResUNetCommon2).
return reg_out, class_out
def forward_branch(self, feat, x1, x2, label):
t4 = getattr(self, f"t4_{label}")
feat = t4(feat[0])  # 1,256,8,28,28
feat = self.upsample([feat])  # 1,256,8,56,56
feat = self.concat(x2, feat)  # 1,256+256,8,56,56
t3 = getattr(self, f"t3_{label}")
feat = t3(feat[0])  # 1,128,8,56,56
feat = self.concat(x1, [feat])  # 1,128+64,8,56,56
conv1x1 = getattr(self, f"conv1x1_{label}")
feat = conv1x1(feat[0])  # single-channel map per branch: 1,1,8,56,56
return feat
@MODEL_REGISTRY.register()
class ResUNetCommon2(nn.Module):
"""
ResNet model builder. It builds a ResNet like network backbone without
lateral connection (C2D, I3D, Slow).
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"SlowFast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
"Non-local neural networks."
https://arxiv.org/pdf/1711.07971.pdf
"""
def __init__(self, cfg):
"""
The `__init__` method of any subclass should also contain these
arguments.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
super(ResUNetCommon2, self).__init__()
self.norm_module = get_norm(cfg)
self.enable_detection = cfg.DETECTION.ENABLE
self.enable_jitter = cfg.JITTER.ENABLE
self.num_pathways = 1
assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
self.image_size = cfg.DATA.TRAIN_CROP_SIZE
self.clip_size = cfg.DATA.NUM_FRAMES
self._construct_network(cfg)
init_helper.init_weights(
self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
)
def _construct_network(self, cfg):
"""
Builds a single pathway ResNet model.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
assert cfg.MODEL.ARCH in _POOL1.keys()
pool_size = _POOL1[cfg.MODEL.ARCH]
self.cfg = cfg
assert len({len(pool_size), self.num_pathways}) == 1
assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
(d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]
num_groups = cfg.RESNET.NUM_GROUPS
width_per_group = cfg.RESNET.WIDTH_PER_GROUP
dim_inner = num_groups * width_per_group
temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]
self.s1 = stem_helper.VideoModelStem(
dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
dim_out=[width_per_group],
kernel=[temp_kernel[0][0] + [7, 7]],
stride=[[1, 2, 2]],
padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
norm_module=self.norm_module,
)
self.s2 = resnet_helper.ResStage(
dim_in=[width_per_group],
dim_out=[width_per_group * 4],
dim_inner=[dim_inner],
temp_kernel_sizes=temp_kernel[1],
stride=cfg.RESNET.SPATIAL_STRIDES[0],
num_blocks=[d2],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
nonlocal_group=cfg.NONLOCAL.GROUP[0],
nonlocal_pool=cfg.NONLOCAL.POOL[0],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
norm_module=self.norm_module,
)
for pathway in range(self.num_pathways):
pool = nn.MaxPool3d(
kernel_size=pool_size[pathway],
stride=pool_size[pathway],
padding=[0, 0, 0],
)
self.add_module("pathway{}_pool".format(pathway), pool)
self.s3 = resnet_helper.ResStage(
dim_in=[width_per_group * 4],
dim_out=[width_per_group * 8],
dim_inner=[dim_inner * 2],
temp_kernel_sizes=temp_kernel[2],
stride=cfg.RESNET.SPATIAL_STRIDES[1],
num_blocks=[d3],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
nonlocal_group=cfg.NONLOCAL.GROUP[1],
nonlocal_pool=cfg.NONLOCAL.POOL[1],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
norm_module=self.norm_module,
)
self.s4 = resnet_helper.ResStage(
dim_in=[width_per_group * 8],
dim_out=[width_per_group * 16],
dim_inner=[dim_inner * 4],
temp_kernel_sizes=temp_kernel[3],
stride=cfg.RESNET.SPATIAL_STRIDES[2],
num_blocks=[d4],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
nonlocal_group=cfg.NONLOCAL.GROUP[2],
nonlocal_pool=cfg.NONLOCAL.POOL[2],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
norm_module=self.norm_module,
)
# self.s5 = resnet_helper.ResStage(
# dim_in=[width_per_group * 16],
# dim_out=[width_per_group * 32],
# dim_inner=[dim_inner * 8],
# temp_kernel_sizes=temp_kernel[4],
# stride=cfg.RESNET.SPATIAL_STRIDES[3],
# num_blocks=[d5],
# num_groups=[num_groups],
# num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
# nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
# nonlocal_group=cfg.NONLOCAL.GROUP[3],
# nonlocal_pool=cfg.NONLOCAL.POOL[3],
# instantiation=cfg.NONLOCAL.INSTANTIATION,
# trans_func_name=cfg.RESNET.TRANS_FUNC,
# stride_1x1=cfg.RESNET.STRIDE_1X1,
# inplace_relu=cfg.RESNET.INPLACE_RELU,
# dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
# norm_module=self.norm_module,
# )
self.labels = cfg.RESNET.LABELS
self.dual_define("t4", self.labels, LightDecoderBlock(width_per_group * 16, width_per_group * 8, width_per_group * 4))
self.dual_define("t3", self.labels, LightDecoderBlock(width_per_group * 4, width_per_group * 4, 128))
self.dual_define("conv1x1", self.labels, nn.Sequential(
nn.Conv3d(128 + width_per_group, 64, kernel_size=(1, 1, 1), stride=1, padding=0),
nn.BatchNorm3d(64),
nn.ReLU(),
nn.Conv3d(64, 1, kernel_size=(1, 1, 1), stride=1, padding=0),
))
self.linear = nn.Linear(1, 1)
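# Same decoder heads as ResUNetCommon; the differences are the
# nn.Linear(1, 1) classification head and the sigmoid applied to
# class_out in forward().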
def forward_plus(self, x, y, net):
return [net(x)[0] + y[0]]
def dual_define(self, name, labels, net):
for label in labels:
self.add_module(f"{name}_{label}", copy.deepcopy(net))
def upsample(self, x, dims=("space",)):
ori_size = x[0].shape[2:5]
t, h, w = ori_size
if "space" in dims:
h = 2 * h
w = 2 * w
if "time" in dims:
t = 2 * t
size = (t, h, w)
return [F.interpolate(x[0], size)]
def concat(self, x, y):
return [torch.cat([x[0], y[0]], 1)]
def get_detach_var(self, x):
return [t.detach() for t in x]
# @torchsnooper.snoop()
def forward(self, x, freeze_backbone=False):
x = self.get_detach_var(x)
x1 = self.s1(x)  # 1,64,8,56,56
x2 = self.s2(x1)  # 1,256,8,56,56
x3 = self.s3(x2)  # 1,512,8,28,28
feat = self.s4(x3)  # 1,1024,8,14,14
assert isinstance(freeze_backbone, bool)
if freeze_backbone:
feat = self.get_detach_var(feat)
x1 = self.get_detach_var(x1)
x2 = self.get_detach_var(x2)
x3 = self.get_detach_var(x3)
feat = self.upsample(feat)  # 1,1024,8,28,28
feat = self.concat(x3, feat)  # 1,1024+512,8,28,28
reg_out = [self.forward_branch(feat, x1, x2, label) for label in self.labels]
reg_out = torch.cat(reg_out, 1)
reg_out = torch.sigmoid(reg_out)  # dense per-label maps: N,num_labels,8,56,56
class_out = reg_out.mean([3, 4]).view(-1, 1) * 100  # spatial mean, flattened to (N * num_labels * T, 1)
class_out = self.linear(class_out)
class_out = class_out.view(reg_out.size(0), len(self.labels), -1)
class_out = torch.sigmoid(class_out)
return reg_out, class_out
def forward_branch(self, feat, x1, x2, label):
t4 = getattr(self, f"t4_{label}")
feat = t4(feat[0])  # 1,256,8,28,28
feat = self.upsample([feat])  # 1,256,8,56,56
feat = self.concat(x2, feat)  # 1,256+256,8,56,56
t3 = getattr(self, f"t3_{label}")
feat = t3(feat[0])  # 1,128,8,56,56
feat = self.concat(x1, [feat])  # 1,128+64,8,56,56
conv1x1 = getattr(self, f"conv1x1_{label}")
feat = conv1x1(feat[0])  # single-channel map per branch: 1,1,8,56,56
return feat
@MODEL_REGISTRY.register()
class ResUNetStrong(nn.Module):
"""
ResNet model builder. It builds a ResNet like network backbone without
lateral connection (C2D, I3D, Slow).
Christoph Feichtenhofer, Haoqi Fan, Jitendra Malik, and Kaiming He.
"SlowFast networks for video recognition."
https://arxiv.org/pdf/1812.03982.pdf
Xiaolong Wang, Ross Girshick, Abhinav Gupta, and Kaiming He.
"Non-local neural networks."
https://arxiv.org/pdf/1711.07971.pdf
"""
def __init__(self, cfg):
"""
The `__init__` method of any subclass should also contain these
arguments.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
super(ResUNetStrong, self).__init__()
self.norm_module = get_norm(cfg)
self.enable_detection = cfg.DETECTION.ENABLE
self.enable_jitter = cfg.JITTER.ENABLE
self.num_pathways = 1
assert cfg.DATA.TRAIN_CROP_SIZE == cfg.DATA.TEST_CROP_SIZE
self.image_size = cfg.DATA.TRAIN_CROP_SIZE
self.clip_size = cfg.DATA.NUM_FRAMES
self._construct_network(cfg)
init_helper.init_weights(
self, cfg.MODEL.FC_INIT_STD, cfg.RESNET.ZERO_INIT_FINAL_BN
)
def _construct_network(self, cfg):
"""
Builds a single pathway ResNet model.
Args:
cfg (CfgNode): model building configs, details are in the
comments of the config file.
"""
assert cfg.MODEL.ARCH in _POOL1.keys()
pool_size = _POOL1[cfg.MODEL.ARCH]
self.cfg = cfg
assert len({len(pool_size), self.num_pathways}) == 1
assert cfg.RESNET.DEPTH in _MODEL_STAGE_DEPTH.keys()
(d2, d3, d4, d5) = _MODEL_STAGE_DEPTH[cfg.RESNET.DEPTH]
num_groups = cfg.RESNET.NUM_GROUPS
width_per_group = cfg.RESNET.WIDTH_PER_GROUP
dim_inner = num_groups * width_per_group
temp_kernel = _TEMPORAL_KERNEL_BASIS[cfg.MODEL.ARCH]
self.s1 = stem_helper.VideoModelStem(
dim_in=cfg.DATA.INPUT_CHANNEL_NUM,
dim_out=[width_per_group],
kernel=[temp_kernel[0][0] + [7, 7]],
stride=[[1, 2, 2]],
padding=[[temp_kernel[0][0][0] // 2, 3, 3]],
norm_module=self.norm_module,
)
self.s2 = resnet_helper.ResStage(
dim_in=[width_per_group],
dim_out=[width_per_group * 4],
dim_inner=[dim_inner],
temp_kernel_sizes=temp_kernel[1],
stride=cfg.RESNET.SPATIAL_STRIDES[0],
num_blocks=[d2],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[0],
nonlocal_inds=cfg.NONLOCAL.LOCATION[0],
nonlocal_group=cfg.NONLOCAL.GROUP[0],
nonlocal_pool=cfg.NONLOCAL.POOL[0],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[0],
norm_module=self.norm_module,
)
for pathway in range(self.num_pathways):
pool = nn.MaxPool3d(
kernel_size=pool_size[pathway],
stride=pool_size[pathway],
padding=[0, 0, 0],
)
self.add_module("pathway{}_pool".format(pathway), pool)
self.s3 = resnet_helper.ResStage(
dim_in=[width_per_group * 4],
dim_out=[width_per_group * 8],
dim_inner=[dim_inner * 2],
temp_kernel_sizes=temp_kernel[2],
stride=cfg.RESNET.SPATIAL_STRIDES[1],
num_blocks=[d3],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[1],
nonlocal_inds=cfg.NONLOCAL.LOCATION[1],
nonlocal_group=cfg.NONLOCAL.GROUP[1],
nonlocal_pool=cfg.NONLOCAL.POOL[1],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[1],
norm_module=self.norm_module,
)
self.s4 = resnet_helper.ResStage(
dim_in=[width_per_group * 8],
dim_out=[width_per_group * 16],
dim_inner=[dim_inner * 4],
temp_kernel_sizes=temp_kernel[3],
stride=cfg.RESNET.SPATIAL_STRIDES[2],
num_blocks=[d4],
num_groups=[num_groups],
num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[2],
nonlocal_inds=cfg.NONLOCAL.LOCATION[2],
nonlocal_group=cfg.NONLOCAL.GROUP[2],
nonlocal_pool=cfg.NONLOCAL.POOL[2],
instantiation=cfg.NONLOCAL.INSTANTIATION,
trans_func_name=cfg.RESNET.TRANS_FUNC,
stride_1x1=cfg.RESNET.STRIDE_1X1,
inplace_relu=cfg.RESNET.INPLACE_RELU,
dilation=cfg.RESNET.SPATIAL_DILATIONS[2],
norm_module=self.norm_module,
)
# self.s5 = resnet_helper.ResStage(
# dim_in=[width_per_group * 16],
# dim_out=[width_per_group * 32],
# dim_inner=[dim_inner * 8],
# temp_kernel_sizes=temp_kernel[4],
# stride=cfg.RESNET.SPATIAL_STRIDES[3],
# num_blocks=[d5],
# num_groups=[num_groups],
# num_block_temp_kernel=cfg.RESNET.NUM_BLOCK_TEMP_KERNEL[3],
# nonlocal_inds=cfg.NONLOCAL.LOCATION[3],
# nonlocal_group=cfg.NONLOCAL.GROUP[3],
# nonlocal_pool=cfg.NONLOCAL.POOL[3],
# instantiation=cfg.NONLOCAL.INSTANTIATION,
# trans_func_name=cfg.RESNET.TRANS_FUNC,
# stride_1x1=cfg.RESNET.STRIDE_1X1,
# inplace_relu=cfg.RESNET.INPLACE_RELU,
# dilation=cfg.RESNET.SPATIAL_DILATIONS[3],
# norm_module=self.norm_module,
# )
self.labels = cfg.RESNET.LABELS
self.dual_define("t4", self.labels, ResDecoderBlock(width_per_group * 16, width_per_group * 8, width_per_group * 8))
self.dual_define("t3", self.labels, ResDecoderBlock(width_per_group * 8, width_per_group * 4, 256))
self.dual_define("conv1x1", self.labels, nn.Sequential(
nn.Conv3d(width_per_group * 4 + width_per_group, 128, kernel_size=(1, 1, 1), stride=1, padding=0),
nn.BatchNorm3d(128),
nn.ReLU(),
nn.Conv3d(128, 1, kernel_size=(1, 1, 1), stride=1, padding=0),
))
self.linear = nn.Linear(1, 1)
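# ResUNetStrong swaps the LightDecoderBlock heads used above for
# ResDecoderBlock heads with wider widths (assuming the last constructor
# argument is the block's output channel count: 512 and 256 instead of
# 256 and 128) and a 128-channel conv1x1 head; the classification path
# (Linear(1, 1) followed by a sigmoid in forward) matches ResUNetCommon2.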
def forward_plus(self, x, y, net):
return [net(x)[0] + y[0]]
def dual_define(self, name, labels, net):
for label in labels:
self.add_module(f"{name}_{label}", copy.deepcopy(net))
def upsample(self, x, dims=("space",)):
ori_size = x[0].shape[2:5]
t, h, w = ori_size
if "space" in dims:
h = 2 * h
w = 2 * w
if "time" in dims:
t = 2 * t
size = (t, h, w)
return [F.interpolate(x[0], size)]
def concat(self, x, y):
return [torch.cat([x[0], y[0]], 1)]
def get_detach_var(self, x):
return [t.detach() for t in x]
# @torchsnooper.snoop()
def forward(self, x, freeze_backbone=False):
x = self.get_detach_var(x)
x1 = self.s1(x)  # 1,64,8,56,56
x2 = self.s2(x1)  # 1,256,8,56,56
x3 = self.s3(x2)  # 1,512,8,28,28
feat = self.s4(x3)  # 1,1024,8,14,14
assert isinstance(freeze_backbone, bool)
if freeze_backbone:
feat = self.get_detach_var(feat)
x1 = self.get_detach_var(x1)
x2 = self.get_detach_var(x2)
x3 = self.get_detach_var(x3)
feat = self.upsample(feat)  # 1,1024,8,28,28
feat = self.concat(x3, feat)  # 1,1024+512,8,28,28
reg_out = [self.forward_branch(feat, x1, x2, label) for label in self.labels]
reg_out = torch.cat(reg_out, 1)
reg_out = torch.sigmoid(reg_out)  # dense per-label maps: N,num_labels,8,56,56
class_out = reg_out.mean([3, 4]).view(-1, 1) * 100  # spatial mean, flattened to (N * num_labels * T, 1)
class_out = self.linear(class_out)
class_out = class_out.view(reg_out.size(0), len(self.labels), -1)
class_out = torch.sigmoid(class_out)
return reg_out, class_out
def forward_branch(self, feat, x1, x2, label):
t4 = getattr(self, f"t4_{label}")
feat = t4(feat[0])  # 1,512,8,28,28
feat = self.upsample([feat])  # 1,512,8,56,56
feat = self.concat(x2, feat)  # 1,512+256,8,56,56
t3 = getattr(self, f"t3_{label}")
feat = t3(feat[0])  # 1,256,8,56,56
feat = self.concat(x1, [feat])  # 1,256+64,8,56,56
conv1x1 = getattr(self, f"conv1x1_{label}")
feat = conv1x1(feat[0])  # single-channel map per branch: 1,1,8,56,56
return feat
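# Shape sanity check (a sketch under assumptions, not part of the model
# code): the single-pathway input is a one-element list of (N, C, T, H, W)
# tensors, e.g.
#
#     inputs = [torch.randn(1, 3, cfg.DATA.NUM_FRAMES, crop, crop)]
#     reg_out, class_out = model(inputs)
#     # reg_out:   (N, len(cfg.RESNET.LABELS), T, crop // 2, crop // 2) dense maps
#     # class_out: (N, len(cfg.RESNET.LABELS), -1) per-frame scores
#
# where crop = cfg.DATA.TRAIN_CROP_SIZE; the halved spatial size assumes the
# usual [1, 2, 2] spatial strides, matching the shape comments above.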