ianpan's picture
Initial commit
231edce
import re
import timm
import torch
from functools import partial
from timm.models.vision_transformer import VisionTransformer
from timm.models.swin_transformer_v2 import SwinTransformerV2
from .vmz.backbones import *
def check_name(name, s):
    """Return True when the regular expression ``s`` matches anywhere in ``name``.

    Args:
        name: string to search in (here: a backbone name).
        s: regular-expression pattern handed to ``re.search``.

    Returns:
        bool: ``True`` if ``re.search`` finds a match, ``False`` otherwise.
    """
    # re.search yields a Match object on success and None on failure.
    return re.search(s, name) is not None
def create_backbone(name, pretrained, features_only=False, **kwargs):
    """Build a backbone and report the channel width of its output.

    ``name`` is first offered to ``timm.create_model`` (with ``num_classes=0``
    and ``global_pool=""``). If timm raises, the name is looked up in the
    module-level ``BACKBONES`` registry instead. ``**kwargs`` are forwarded
    only to the registry builder, never to timm.

    Args:
        name: timm model name or a key of ``BACKBONES``.
        pretrained: forwarded to the builder as ``pretrained``.
        features_only: forwarded to the builder as ``features_only``.
        **kwargs: extra options passed to a ``BACKBONES`` builder.

    Returns:
        tuple: ``(model, dim_feats)``. For ``VisionTransformer`` /
        ``SwinTransformerV2`` instances ``dim_feats`` is read from
        ``model.norm.normalized_shape[0]``. Otherwise it is ``size(1)`` of the
        output of a dummy forward pass. If that forward pass returns a list or
        tuple of tensors, ``dim_feats`` is the list of their ``size(1)`` values.

    Raises:
        AssertionError: timm could not build ``name`` and it is not a key of
            ``BACKBONES``. The timm error is attached as ``__cause__``.
    """
    try:
        model = timm.create_model(name, pretrained=pretrained,
                                  features_only=features_only,
                                  num_classes=0, global_pool="")
    except Exception as timm_error:
        # Raised explicitly rather than via `assert` so the check is not
        # stripped under `python -O`; the exception type and message are the
        # same as before, and the timm failure is kept as the cause.
        if name not in BACKBONES:
            raise AssertionError(f"{name} is not a valid backbone") from timm_error
        model = BACKBONES[name](pretrained=pretrained, features_only=features_only, **kwargs)

    # Pick the dummy input for the shape probe: 5-D for the video backbones
    # matched by name, 4-D for everything else. Transformers need no probe.
    if check_name(name, r"x3d|csn|r2plus1d|i3d"):
        probe_shape = (2, 3, 64, 64, 64)
    elif isinstance(model, (VisionTransformer, SwinTransformerV2)):
        return model, model.norm.normalized_shape[0]
    else:
        probe_shape = (2, 3, 128, 128)

    # Run the probe in eval mode. In train mode BatchNorm layers update their
    # running statistics on every forward pass, even under torch.no_grad(),
    # so random noise would leak into (pretrained) running stats.
    # Each module's own flag is saved and restored afterwards, so any
    # mixed train/eval configuration set up by the builder is preserved.
    is_module = isinstance(model, torch.nn.Module)
    saved_modes = [(m, m.training) for m in model.modules()] if is_module else []
    if is_module:
        model.eval()
    try:
        with torch.no_grad():
            out = model(torch.randn(probe_shape))
    finally:
        for module, was_training in saved_modes:
            module.training = was_training

    # Feature-pyramid style models return several feature maps; report the
    # channel width of each instead of failing on `list.size`.
    if isinstance(out, (list, tuple)):
        dim_feats = [feat.size(1) for feat in out]
    else:
        dim_feats = out.size(1)
    return model, dim_feats
def create_csn(name, pretrained, features_only=False, z_strides=[1, 1, 1, 1, 1], **kwargs):
    """Build a pytorchvideo CSN model and replace ``blocks[5]`` with identity.

    Args:
        name: attribute name of the builder in ``pytorchvideo.models.hub``,
            also used as the ``torch.hub`` entry point name.
        pretrained: when truthy the model is fetched through ``torch.hub.load``
            from ``facebookresearch/pytorchvideo:main``; otherwise it is built
            locally with ``pretrained=False``.
        features_only: must be falsy; feature extraction is not implemented.
        z_strides: accepted for signature parity; not read by this function.
        **kwargs: accepted and ignored.

    Returns:
        The model with ``model.blocks[5]`` swapped for ``nn.Identity()``.

    Raises:
        Exception: if ``features_only`` is truthy.
    """
    if features_only:
        raise Exception("features_only is currently not supported")

    if pretrained:
        backbone = torch.hub.load(
            "facebookresearch/pytorchvideo:main", model=name, pretrained=pretrained
        )
    else:
        # Local import: pytorchvideo is only needed on this path.
        from pytorchvideo.models import hub as pv_hub

        builder = getattr(pv_hub, name)
        backbone = builder(pretrained=False)

    backbone.blocks[5] = nn.Identity()
    return backbone
def create_x3d(name, pretrained, features_only=False, z_strides=[1, 1, 1, 1, 1], **kwargs):
    """Build a pytorchvideo X3D model with optional extra temporal striding.

    Args:
        name: attribute name of the builder in ``pytorchvideo.models.hub``,
            also used as the ``torch.hub`` entry point name.
        pretrained: when truthy the model is fetched through ``torch.hub.load``
            from ``facebookresearch/pytorchvideo``; otherwise it is built
            locally with ``pretrained=False``.
        features_only: when truthy the last block is replaced by identity and
            the model is wrapped in ``X3D_Features``; otherwise the last block
            is reduced to its ``pool.pre_conv`` / ``pre_norm`` / ``pre_act``.
        z_strides: one entry per ``model.blocks[idx]``, each 1 or 2. A 2 at
            index 0 replaces the stem's ``conv.conv_t`` with a 3x3x3 conv of
            stride (2, 2, 2). A 2 at any other index sets the stride of the
            first residual block's ``branch1_conv`` and ``branch2.conv_b`` to
            (2, 2, 2). The default list is only read, never mutated.
        **kwargs: accepted and ignored.

    Returns:
        The modified model, or an ``X3D_Features`` wrapper around it.

    Raises:
        AssertionError: if any entry of ``z_strides`` is not 1 or 2.
    """
    # Validate before building so a bad argument fails without first
    # constructing or downloading a model. Raised explicitly rather than via
    # `assert` so the check is not stripped under `python -O`.
    z_strides = list(z_strides)
    for z in z_strides:
        if z not in (1, 2):
            raise AssertionError("Only z-strides of 1 or 2 are supported")

    if not pretrained:
        from pytorchvideo.models import hub
        model = getattr(hub, name)(pretrained=False)
    else:
        model = torch.hub.load("facebookresearch/pytorchvideo", model=name, pretrained=pretrained)

    for idx, z in enumerate(z_strides):
        if z != 2:
            continue
        if idx == 0:
            stem = model.blocks[0].conv
            old_conv = stem.conv_t
            new_conv = nn.Conv3d(
                old_conv.in_channels,
                old_conv.out_channels,
                kernel_size=(3, 3, 3),
                stride=(2, 2, 2),
                padding=(1, 1, 1),
            )
            # Inflate the existing kernel along the depth axis and load it
            # into the replacement conv. Previously the inflated tensor was
            # computed and then dropped, leaving the new stem randomly
            # initialised. Dividing by the new depth keeps the response to a
            # temporally constant input unchanged.
            # Assumes the existing kernel has depth 1, as the x3 repeat
            # implies; copy_ raises on any other shape instead of passing
            # silently.
            # The new conv's bias parameter is left as constructed so the
            # layer's state_dict keys do not change.
            depth = new_conv.kernel_size[0]
            with torch.no_grad():
                new_conv.weight.copy_(old_conv.weight.repeat(1, 1, depth, 1, 1) / depth)
            stem.conv_t = new_conv
        else:
            first_res_block = model.blocks[idx].res_blocks[0]
            first_res_block.branch1_conv.stride = (2, 2, 2)
            first_res_block.branch2.conv_b.stride = (2, 2, 2)

    if features_only:
        model.blocks[-1] = nn.Identity()
        model = X3D_Features(model)
    else:
        # Keep only the pre-pooling conv/norm/activation of the head.
        head_pool = model.blocks[-1].pool
        model.blocks[-1] = nn.Sequential(
            head_pool.pre_conv,
            head_pool.pre_norm,
            head_pool.pre_act,
        )
    return model
def create_i3d(name, pretrained, features_only=False, **kwargs):
    """Build a pytorchvideo I3D model and replace its last block with identity.

    Args:
        name: attribute name of the builder in ``pytorchvideo.models.hub``.
        pretrained: forwarded to that builder as ``pretrained``.
        features_only: accepted for signature parity; not read here.
        **kwargs: accepted and ignored.

    Returns:
        The model with ``model.blocks[-1]`` swapped for ``nn.Identity()``.
    """
    # Local import: pytorchvideo is only needed when this builder is called.
    from pytorchvideo.models import hub as pv_hub

    builder = getattr(pv_hub, name)
    backbone = builder(pretrained=pretrained)
    backbone.blocks[-1] = nn.Identity()
    return backbone
class X3D_Features(nn.Module):
    """Wrapper that returns the output of every block except the last.

    The wrapped ``model`` must expose a sequence ``model.blocks`` of callables.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Hard-coded channel count per returned feature map.
        self.out_channels = [24, 24, 48, 96, 192]

    def forward(self, x):
        """Run ``x`` through all but the final block, collecting each output.

        Returns:
            list: one entry per executed block, in execution order.
        """
        collected = []
        # The final entry of `blocks` is deliberately not executed.
        for stage in list(self.model.blocks)[:-1]:
            x = stage(x)
            collected.append(x)
        return collected
# Registry mapping a backbone name to a callable that builds it.
# `create_backbone` consults this mapping when `timm.create_model` raises.
BACKBONES = {
    # Builders defined in this file, with the model name pre-bound.
    **{
        x3d_name: partial(create_x3d, name=x3d_name)
        for x3d_name in ("x3d_xs", "x3d_s", "x3d_m", "x3d_l")
    },
    "i3d_r50": partial(create_i3d, name="i3d_r50"),
    "csn_r101": partial(create_csn, name="csn_r101"),
    # Constructors referenced directly by name; they are not defined in this
    # file (presumably supplied by the `.vmz.backbones` star import - verify).
    "ir_csn_50": ir_csn_50,
    "ir_csn_101": ir_csn_101,
    "ir_csn_152": ir_csn_152,
    "ip_csn_50": ip_csn_50,
    "ip_csn_101": ip_csn_101,
    "ip_csn_152": ip_csn_152,
    "r2plus1d_34": r2plus1d_34,
}