# NOTE: page-scrape residue, not part of the module: "Spaces: Runtime error".
import torch | |
from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s | |
class ResNetBackbone(torch.nn.Module):
    """Wrap a ResNet variant and expose the outputs of its four layer groups.

    The constructor builds one of the ResNet models imported from
    ``resnetv1b`` and keeps references to its stem (``conv1``, ``bn1``,
    ``relu``, ``maxpool``) and to ``layer1`` .. ``layer4``. No other
    attribute of the built model is retained by this wrapper.
    """

    def __init__(
        self,
        backbone: str = "resnet50",
        pretrained_base: bool = True,
        dilated: bool = True,
        **kwargs,
    ):
        """Build the selected ResNet and adopt its stem and layer groups.

        Args:
            backbone: One of ``"resnet34"``, ``"resnet50"``, ``"resnet101"``
                or ``"resnet152"``.
            pretrained_base: Forwarded to the model factory as ``pretrained``.
            dilated: Forwarded to the model factory as ``dilated``.
            **kwargs: Forwarded unchanged to the model factory.

        Raises:
            RuntimeError: If ``backbone`` is not one of the supported names.
        """
        super().__init__()

        # Only pick the factory here; every variant is built with the same
        # arguments, so the call itself is made once below.
        if backbone == "resnet34":
            build_model = resnet34_v1b
        elif backbone == "resnet50":
            build_model = resnet50_v1s
        elif backbone == "resnet101":
            build_model = resnet101_v1s
        elif backbone == "resnet152":
            build_model = resnet152_v1s
        else:
            raise RuntimeError(f"unknown backbone: {backbone}")

        pretrained = build_model(
            pretrained=pretrained_base, dilated=dilated, **kwargs
        )

        # Stem.
        self.conv1 = pretrained.conv1
        self.bn1 = pretrained.bn1
        self.relu = pretrained.relu
        self.maxpool = pretrained.maxpool
        # Layer groups whose outputs ``forward`` returns.
        self.layer1 = pretrained.layer1
        self.layer2 = pretrained.layer2
        self.layer3 = pretrained.layer3
        self.layer4 = pretrained.layer4

    def forward(self, x, additional_features=None):
        """Run the backbone and return the output of each layer group.

        Args:
            x: Input tensor passed to ``conv1``.
                Presumably laid out as (N, C, H, W) -- TODO confirm.
            additional_features: Optional tensor added to the stem activations
                after ``relu`` and before ``maxpool``. It is padded with zeros
                at the end of dimension 1 so that its size there matches the
                stem output; its remaining dimensions must already be
                compatible with the stem output for the addition to work.

        Returns:
            A tuple ``(c1, c2, c3, c4)`` holding the outputs of ``layer1``,
            ``layer2``, ``layer3`` and ``layer4`` respectively.
        """
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)

        if additional_features is not None:
            # The pad list is read from the last dimension backwards in
            # (before, after) pairs, so only the "after" side of dimension 1
            # receives padding.
            # NOTE(review): when additional_features is larger than x along
            # dimension 1 this amount is negative -- confirm that case is
            # intended by the callers.
            missing_channels = x.size(1) - additional_features.size(1)
            x = x + torch.nn.functional.pad(
                additional_features,
                [0, 0, 0, 0, 0, missing_channels],
                mode="constant",
                value=0,
            )

        x = self.maxpool(x)
        c1 = self.layer1(x)
        c2 = self.layer2(c1)
        c3 = self.layer3(c2)
        c4 = self.layer4(c3)
        return c1, c2, c3, c4