import re

import timm
import torch
import torch.nn as nn
from functools import partial

from timm.models.vision_transformer import VisionTransformer
from timm.models.swin_transformer_v2 import SwinTransformerV2

from .vmz.backbones import *


def check_name(name, s):
    # True if the regex pattern `s` matches anywhere in `name`.
    return bool(re.search(s, name))
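
# e.g. check_name("ir_csn_152", r"x3d|csn|r2plus1d|i3d") -> True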


def create_backbone(name, pretrained, features_only=False, **kwargs):
    # Try timm first; fall back to the local BACKBONES registry (video models).
    try:
        model = timm.create_model(name, pretrained=pretrained,
                                  features_only=features_only,
                                  num_classes=0, global_pool="")
    except Exception:
        assert name in BACKBONES, f"{name} is not a valid backbone"
        model = BACKBONES[name](pretrained=pretrained, features_only=features_only, **kwargs)
    # Infer the feature dimension with a dummy forward pass: video backbones
    # take 5-D input, image backbones 4-D; ViT/Swin expose their embedding
    # width on the final norm layer instead.
    with torch.no_grad():
        if check_name(name, r"x3d|csn|r2plus1d|i3d"):
            dim_feats = model(torch.randn((2, 3, 64, 64, 64))).size(1)
        elif isinstance(model, (VisionTransformer, SwinTransformerV2)):
            dim_feats = model.norm.normalized_shape[0]
        else:
            dim_feats = model(torch.randn((2, 3, 128, 128))).size(1)
    return model, dim_feats
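
# Usage sketch: returns the headless model plus its feature width.
# "resnet18" is just an illustrative timm name, not one this repo requires.
#
#   model, dim_feats = create_backbone("resnet18", pretrained=False)
#   feats = model(torch.randn(2, 3, 128, 128))  # (2, 512, 4, 4) for resnet18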


def create_csn(name, pretrained, features_only=False, z_strides=(1, 1, 1, 1, 1), **kwargs):
    # `z_strides` is accepted for signature parity with create_x3d but unused here.
    if features_only:
        raise NotImplementedError("features_only is currently not supported")
    if not pretrained:
        from pytorchvideo.models import hub
        model = getattr(hub, name)(pretrained=False)
    else:
        model = torch.hub.load("facebookresearch/pytorchvideo:main", model=name, pretrained=pretrained)
    # Drop the classification head (block 5) so the model emits feature maps.
    model.blocks[5] = nn.Identity()
    return model
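
# Sketch (assumes pytorchvideo is installed): a headless CSN trunk.
#
#   csn = create_csn("csn_r101", pretrained=False)
#   feats = csn(torch.randn(1, 3, 8, 64, 64))  # 5-D feature map, head removed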


def create_x3d(name, pretrained, features_only=False, z_strides=(1, 1, 1, 1, 1), **kwargs):
    if not pretrained:
        from pytorchvideo.models import hub
        model = getattr(hub, name)(pretrained=False)
    else:
        model = torch.hub.load("facebookresearch/pytorchvideo", model=name, pretrained=pretrained)
    # Optionally add temporal (z) downsampling at each of the 5 stages.
    for idx, z in enumerate(z_strides):
        assert z in [1, 2], "Only z-strides of 1 or 2 are supported"
        if z == 2:
            if idx == 0:
                # Inflate the stem conv from 1 to 3 temporal taps and stride it
                # in time as well as space; dividing the repeated weights by 3
                # preserves the activation scale of the pretrained filters.
                stem_layer = model.blocks[0].conv.conv_t
                w = stem_layer.weight.detach().repeat(1, 1, 3, 1, 1) / 3.0
                in_channels, out_channels = stem_layer.in_channels, stem_layer.out_channels
                new_stem = nn.Conv3d(in_channels, out_channels, kernel_size=(3, 3, 3),
                                     stride=(2, 2, 2), padding=(1, 1, 1), bias=False)
                new_stem.weight = nn.Parameter(w)
                model.blocks[0].conv.conv_t = new_stem
            else:
                # Later stages: stride the first residual block temporally too.
                model.blocks[idx].res_blocks[0].branch1_conv.stride = (2, 2, 2)
                model.blocks[idx].res_blocks[0].branch2.conv_b.stride = (2, 2, 2)
    if features_only:
        # Replace the head with a pass-through and wrap the trunk so it
        # returns per-stage feature maps.
        model.blocks[-1] = nn.Identity()
        model = X3D_Features(model)
    else:
        # Keep only the head's pre-pool conv/norm/activation.
        model.blocks[-1] = nn.Sequential(
            model.blocks[-1].pool.pre_conv,
            model.blocks[-1].pool.pre_norm,
            model.blocks[-1].pool.pre_act,
        )
    return model
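
# Sketch (assumes pytorchvideo is installed): an x3d_m trunk whose stem and
# first res stage also downsample temporally.
#
#   m = create_x3d("x3d_m", pretrained=False, z_strides=[2, 2, 1, 1, 1])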


def create_i3d(name, pretrained, features_only=False, **kwargs):
    from pytorchvideo.models import hub
    model = getattr(hub, name)(pretrained=pretrained)
    # Drop the classification head so the model emits feature maps.
    model.blocks[-1] = nn.Identity()
    return model


class X3D_Features(nn.Module):
    """Wraps an X3D trunk so the forward pass returns per-stage feature maps."""

    def __init__(self, model):
        super().__init__()
        self.model = model
        # Channel widths of the five stages (stem + 4 res stages).
        self.out_channels = [24, 24, 48, 96, 192]

    def forward(self, x):
        features = []
        # Run every block except the (removed) head and collect each output.
        for idx in range(len(self.model.blocks) - 1):
            x = self.model.blocks[idx](x)
            features.append(x)
        return features
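
# Sketch of the features_only path: one feature map per stage, with channel
# widths matching `out_channels` above.
#
#   fx = create_x3d("x3d_m", pretrained=False, features_only=True)
#   feats = fx(torch.randn(1, 3, 16, 128, 128))  # list of 5 tensors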


# Registry of video backbones that timm does not provide; create_backbone
# falls back to this when timm.create_model raises.
BACKBONES = {
    "x3d_xs": partial(create_x3d, name="x3d_xs"),
    "x3d_s": partial(create_x3d, name="x3d_s"),
    "x3d_m": partial(create_x3d, name="x3d_m"),
    "x3d_l": partial(create_x3d, name="x3d_l"),
    "i3d_r50": partial(create_i3d, name="i3d_r50"),
    "csn_r101": partial(create_csn, name="csn_r101"),
    "ir_csn_50": ir_csn_50,
    "ir_csn_101": ir_csn_101,
    "ir_csn_152": ir_csn_152,
    "ip_csn_50": ip_csn_50,
    "ip_csn_101": ip_csn_101,
    "ip_csn_152": ip_csn_152,
    "r2plus1d_34": r2plus1d_34,
}
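
# Sketch: names that timm does not recognize fall back to this registry, e.g.
#
#   model, dim_feats = create_backbone("ir_csn_152", pretrained=False)
#
# routes through the vmz `ir_csn_152` constructor imported above.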