Spaces:
Runtime error
Runtime error
import re | |
import timm | |
import torch | |
import torch.nn as nn | |
from ...backbones import create_x3d | |
from ...tools import change_num_input_channels | |
from .swin_encoder import SwinTransformer | |
def get_attribute(model, name):
    """Resolve a dotted path such as ``"layer4.0.conv1"`` starting from ``model``.

    Each component is first looked up with ``getattr``. If that fails and the
    component is a decimal number, it is used as an integer index instead, so
    paths that run through plain sequences (lists, tuples) resolve as well.

    Args:
        model: Root object to start the lookup from.
        name: Dot-separated path of attribute names and/or integer indices.

    Returns:
        The object found at the end of the path.

    Raises:
        AttributeError: If a non-numeric component is not an attribute of the
            object reached so far.
    """
    attr = model
    for part in name.split("."):
        try:
            # Attribute lookup comes first: this is what the path components
            # produced by `str.split` have always resolved through.
            attr = getattr(attr, part)
        except AttributeError:
            # `split` only yields strings, so an index has to be recognised
            # by its content rather than by its type.
            if not part.isdecimal():
                raise
            attr = attr[int(part)]
    return attr
def check_if_int(s):
    """Tell whether ``int(s)`` succeeds.

    Returns ``True`` when the conversion works and ``False`` when it raises
    ``ValueError``. Any other exception raised by ``int`` propagates.
    """
    try:
        int(s)
    except ValueError:
        return False
    else:
        return True
def _stride2_module_names(encoder):
    """Return names of submodules whose ``stride`` equals ``(2, 2)``, in reverse registration order."""
    found = [
        module_name
        for module_name, module in encoder.named_modules()
        if getattr(module, "stride", None) == (2, 2)
    ]
    return found[::-1]


def _replace_submodule(encoder, module_name, new_module):
    """Put ``new_module`` in place of the submodule at dotted path ``module_name``."""
    parent_name, _, leaf = module_name.rpartition(".")
    parent = get_attribute(encoder, parent_name) if parent_name else encoder
    if check_if_int(leaf):
        # A numeric last component means the parent is an indexable container,
        # so the replacement has to be assigned by index.
        parent[int(leaf)] = new_module
    else:
        setattr(parent, leaf, new_module)


def _pointwise_conv_from(conv):
    """Build a 1x1, stride-1 ``nn.Conv2d`` from ``conv`` by averaging its kernel over the spatial dims."""
    has_bias = conv.bias is not None
    pointwise = nn.Conv2d(
        conv.in_channels, conv.out_channels, kernel_size=1, stride=1, bias=has_bias
    )
    weight = conv.weight.detach().mean(-1, keepdim=True).mean(-2, keepdim=True)
    pointwise.weight = nn.Parameter(weight)
    if has_bias:
        # Carry the existing bias over instead of keeping the fresh random one.
        pointwise.bias = nn.Parameter(conv.bias.detach().clone())
    return pointwise


def _reduce_timm_output_stride(encoder, name, encoder_output_stride):
    """Modify a timm feature encoder in place so its last feature map has the requested stride."""
    # Default for pretty much every model is 32.
    # First, ensure that the provided stride is valid.
    assert 32 % encoder_output_stride == 0
    scale_factor = 32 // encoder_output_stride
    layers_to_modify = 1 if scale_factor == 2 else 2

    if re.search(r"resnest", name):
        if encoder_output_stride not in [8, 16]:
            raise Exception(f"{name} only supports output stride of 8, 16, or 32")
        encoder.layer4[0].downsample[0] = nn.Identity()
        encoder.layer4[0].avd_last = nn.Identity()
        if encoder_output_stride == 8:
            encoder.layer3[0].downsample[0] = nn.Identity()
            encoder.layer3[0].avd_last = nn.Identity()
    elif re.search(r"resnet[0-9]+d", name):
        if encoder_output_stride not in [8, 16]:
            raise Exception(f"{name} only supports output stride of 8, 16, or 32")
        encoder.layer4[0].downsample[0] = nn.Identity()
        encoder.layer4[0].conv1.stride = (1, 1)
        encoder.layer4[0].conv2.stride = (1, 1)
        if encoder_output_stride == 8:
            encoder.layer3[0].downsample[0] = nn.Identity()
            encoder.layer3[0].conv1.stride = (1, 1)
            encoder.layer3[0].conv2.stride = (1, 1)
    elif re.search(r"regnet[xy]", name):
        # Twice as many stride-2 modules are reset here as for the other
        # families handled below.
        for module_name in _stride2_module_names(encoder)[: layers_to_modify * 2]:
            get_attribute(encoder, module_name).stride = (1, 1)
    elif re.search(r"efficientnet|regnetz|rexnet", name):
        for module_name in _stride2_module_names(encoder)[:layers_to_modify]:
            get_attribute(encoder, module_name).stride = (1, 1)
    elif re.search(r"convnext", name):
        # Changing the stride alone is not enough here: the kernel size has to
        # change as well, so each selected conv is replaced by a 1x1 conv
        # whose weights are the spatial mean of the original kernel.
        for module_name in _stride2_module_names(encoder)[:layers_to_modify]:
            old_conv = get_attribute(encoder, module_name)
            _replace_submodule(encoder, module_name, _pointwise_conv_from(old_conv))
    else:
        raise Exception(f"{name} is not yet supported for output stride < 32")


def create_encoder(name, encoder_params, encoder_output_stride=32, in_channels=3):
    """Create a feature encoder and verify its output stride with a dummy forward pass.

    Args:
        name: ``"swin"`` selects ``SwinTransformer``; any name containing
            ``"x3d"`` goes through ``create_x3d``; everything else is passed to
            ``timm.create_model(..., features_only=True)``.
        encoder_params: Keyword arguments forwarded to the encoder constructor.
            Must contain the key ``"pretrained"``.
        encoder_output_stride: Desired ratio between input size and the size of
            the last feature map. ``swin`` accepts only 32, ``x3d`` accepts 16
            or 32, and timm models are modified per family when it is not 32.
        in_channels: Number of input channels. When it is not 3, the encoder is
            passed through ``change_num_input_channels`` after the stride check.

    Returns:
        The encoder, with ``output_stride`` set (and ``out_channels`` set for
        timm encoders).

    Raises:
        AssertionError: If ``"pretrained"`` is missing, the stride is invalid
            for the chosen encoder, or the measured stride differs from the
            requested one.
        Exception: If a timm model family does not support the requested stride.
    """
    assert "pretrained" in encoder_params
    if name == "swin":
        assert encoder_output_stride == 32, "`swin` encoders only support output_stride=32"
        encoder = SwinTransformer(**encoder_params)
    elif "x3d" in name:
        encoder = create_x3d(name, features_only=True, **encoder_params)
        assert encoder_output_stride in [16, 32]
        if encoder_output_stride == 16:
            encoder.model.blocks[-2].res_blocks[0].branch1_conv.stride = (1, 1, 1)
            encoder.model.blocks[-2].res_blocks[0].branch2.conv_b.stride = (1, 1, 1)
    else:
        encoder = timm.create_model(name, features_only=True, **encoder_params)
        encoder.out_channels = encoder.feature_info.channels()
        if encoder_output_stride != 32:
            _reduce_timm_output_stride(encoder, name, encoder_output_stride)

    # Run a quick test to make sure the output stride is correct.
    if "x3d" in name:
        x = torch.randn((2, 3, 64, 64, 64))
    else:
        x = torch.randn((2, 3, 128, 128))
    # Probe in eval mode without autograd so the random input cannot update
    # normalisation running statistics or build a graph; the previous mode is
    # restored afterwards.
    was_training = getattr(encoder, "training", False)
    if was_training:
        encoder.eval()
    try:
        with torch.no_grad():
            final_out = encoder(x)[-1]
    finally:
        if was_training:
            encoder.train()
    actual_output_stride = x.size(-1) // final_out.size(-1)
    assert actual_output_stride == encoder_output_stride, f"Actual output stride [{actual_output_stride}] does not equal desired output stride [{encoder_output_stride}]"
    print(f"Confirmed encoder output stride {encoder_output_stride} !")

    encoder.output_stride = encoder_output_stride
    if in_channels != 3:
        encoder = change_num_input_channels(encoder, in_channels)
    return encoder