import re
from functools import partial

import timm
import torch
from torch import nn  # required for nn.Identity / nn.Conv3d / nn.Module below
from timm.models.vision_transformer import VisionTransformer
from timm.models.swin_transformer_v2 import SwinTransformerV2

from .vmz.backbones import *


def check_name(name, s):
    """Return True if regex pattern *s* matches anywhere in *name*."""
    return bool(re.search(s, name))


def create_backbone(name, pretrained, features_only=False, **kwargs):
    """Instantiate a backbone by name and probe its output feature dimension.

    Tries ``timm.create_model`` first; on any failure falls back to the local
    ``BACKBONES`` registry (video models).

    Args:
        name: backbone identifier (timm model name or a ``BACKBONES`` key).
        pretrained: load pretrained weights.
        features_only: request intermediate feature maps instead of the head.
        **kwargs: forwarded to the registry factory (e.g. ``z_strides``).

    Returns:
        (model, dim_feats) where ``dim_feats`` is the channel dimension of the
        backbone's output.

    Raises:
        ValueError: if *name* is neither a timm model nor a registry entry.
    """
    try:
        model = timm.create_model(
            name,
            pretrained=pretrained,
            features_only=features_only,
            num_classes=0,
            global_pool="",
        )
    except Exception:
        # Not a timm model -- fall back to the video-backbone registry.
        # Explicit raise instead of assert: asserts are stripped under -O.
        if name not in BACKBONES:
            raise ValueError(f"{name} is not a valid backbone")
        model = BACKBONES[name](pretrained=pretrained, features_only=features_only, **kwargs)
    with torch.no_grad():
        if check_name(name, r"x3d|csn|r2plus1d|i3d"):
            # Video backbones: probe with a dummy (N, C, T, H, W) clip.
            dim_feats = model(torch.randn((2, 3, 64, 64, 64))).size(1)
        elif isinstance(model, (VisionTransformer, SwinTransformerV2)):
            # Transformers expose the dimension directly on the final norm.
            dim_feats = model.norm.normalized_shape[0]
        else:
            # 2D CNNs: probe with a dummy (N, C, H, W) image.
            dim_feats = model(torch.randn((2, 3, 128, 128))).size(1)
    return model, dim_feats


def create_csn(name, pretrained, features_only=False, z_strides=None, **kwargs):
    """Build a pytorchvideo CSN model with its classification head removed.

    Args:
        name: pytorchvideo hub model name (e.g. ``csn_r101``).
        pretrained: load pretrained weights via torch.hub.
        features_only: not supported for CSN.
        z_strides: accepted for signature parity with ``create_x3d``; unused here.

    Raises:
        Exception: if ``features_only`` is requested.
    """
    if features_only:
        raise Exception("features_only is currently not supported")
    if not pretrained:
        from pytorchvideo.models import hub

        model = getattr(hub, name)(pretrained=False)
    else:
        model = torch.hub.load(
            "facebookresearch/pytorchvideo:main", model=name, pretrained=pretrained
        )
    # Drop the final (classification) block.
    model.blocks[5] = nn.Identity()
    return model


def create_x3d(name, pretrained, features_only=False, z_strides=None, **kwargs):
    """Build a pytorchvideo X3D model, optionally increasing temporal strides.

    Args:
        name: pytorchvideo hub model name (``x3d_xs`` .. ``x3d_l``).
        pretrained: load pretrained weights via torch.hub.
        features_only: return per-stage feature maps via ``X3D_Features``.
        z_strides: per-stage temporal strides (1 or 2), one entry per block.
            Defaults to all 1s (no change to the pretrained architecture).

    Raises:
        ValueError: if any entry of ``z_strides`` is not 1 or 2.
    """
    if z_strides is None:
        # Avoid a mutable default argument; [1]*5 leaves the model unchanged.
        z_strides = [1, 1, 1, 1, 1]
    if not pretrained:
        from pytorchvideo.models import hub

        model = getattr(hub, name)(pretrained=False)
    else:
        model = torch.hub.load(
            "facebookresearch/pytorchvideo", model=name, pretrained=pretrained
        )
    for idx, z in enumerate(z_strides):
        if z not in (1, 2):
            raise ValueError("Only z-strides of 1 or 2 are supported")
        if z == 2:
            if idx == 0:
                # Replace the stem's temporal conv with a temporally-strided one.
                # NOTE(review): the original code computed
                #   w = stem_layer.weight.repeat(1, 1, 3, 1, 1)
                # but never copied it into the new conv, so the replacement has
                # always been randomly initialized. The repeated weight's shape
                # does not obviously match a (3, 3, 3) kernel either, so the
                # dead computation was removed rather than "fixed" -- confirm
                # the intended weight inflation against the X3D stem layout.
                stem_layer = model.blocks[0].conv.conv_t
                in_channels, out_channels = stem_layer.in_channels, stem_layer.out_channels
                model.blocks[0].conv.conv_t = nn.Conv3d(
                    in_channels,
                    out_channels,
                    kernel_size=(3, 3, 3),
                    stride=(2, 2, 2),
                    padding=(1, 1, 1),
                )
            else:
                # Stride both the shortcut and the main-branch conv so the
                # residual shapes stay consistent.
                model.blocks[idx].res_blocks[0].branch1_conv.stride = (2, 2, 2)
                model.blocks[idx].res_blocks[0].branch2.conv_b.stride = (2, 2, 2)
    if features_only:
        model.blocks[-1] = nn.Identity()
        model = X3D_Features(model)
    else:
        # Keep the head's pre-pool conv/norm/act, drop pooling and projection.
        model.blocks[-1] = nn.Sequential(
            model.blocks[-1].pool.pre_conv,
            model.blocks[-1].pool.pre_norm,
            model.blocks[-1].pool.pre_act,
        )
    return model


def create_i3d(name, pretrained, features_only=False, **kwargs):
    """Build a pytorchvideo I3D model with its classification head removed."""
    from pytorchvideo.models import hub

    model = getattr(hub, name)(pretrained=pretrained)
    model.blocks[-1] = nn.Identity()
    return model


class X3D_Features(nn.Module):
    """Wraps an X3D model to return the feature maps of every stage.

    The final block (the replaced head) is skipped; ``out_channels`` lists the
    channel count of each returned stage.
    """

    def __init__(self, model):
        super().__init__()
        self.model = model
        self.out_channels = [24, 24, 48, 96, 192]

    def forward(self, x):
        # Run every block except the last (Identity head), collecting each
        # stage's output.
        features = []
        for idx in range(len(self.model.blocks) - 1):
            x = self.model.blocks[idx](x)
            features.append(x)
        return features


# Registry of non-timm (video) backbones, keyed by name.
BACKBONES = {
    "x3d_xs": partial(create_x3d, name="x3d_xs"),
    "x3d_s": partial(create_x3d, name="x3d_s"),
    "x3d_m": partial(create_x3d, name="x3d_m"),
    "x3d_l": partial(create_x3d, name="x3d_l"),
    "i3d_r50": partial(create_i3d, name="i3d_r50"),
    "csn_r101": partial(create_csn, name="csn_r101"),
    "ir_csn_50": ir_csn_50,
    "ir_csn_101": ir_csn_101,
    "ir_csn_152": ir_csn_152,
    "ip_csn_50": ip_csn_50,
    "ip_csn_101": ip_csn_101,
    "ip_csn_152": ip_csn_152,
    "r2plus1d_34": r2plus1d_34,
}