# curt-park — commit 1615d09: "Refactor code"
import torch
from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s
class ResNetBackbone(torch.nn.Module):
    """Feature extractor built from a ``resnetv1b`` ResNet variant.

    The requested network is constructed once. Only its stem (``conv1``,
    ``bn1``, ``relu``, ``maxpool``) and its four residual stages (``layer1``
    .. ``layer4``) are kept as attributes of this module. No other part of
    the constructed network is referenced afterwards. ``forward`` returns
    the output of every stage, so callers can consume features from each
    depth.
    """

    def __init__(
        self, backbone="resnet50", pretrained_base=True, dilated=True, **kwargs
    ):
        """Build the wrapped ResNet and adopt its stem and residual stages.

        Args:
            backbone: One of ``"resnet34"``, ``"resnet50"``, ``"resnet101"``
                or ``"resnet152"``. ``"resnet34"`` maps to ``resnet34_v1b``.
                The other names map to the matching ``*_v1s`` factory.
            pretrained_base: Forwarded to the factory as ``pretrained=``.
                It presumably controls loading of pretrained weights; see
                ``resnetv1b`` to confirm.
            dilated: Forwarded to the factory as ``dilated=``.
            **kwargs: Forwarded unchanged to the factory.

        Raises:
            RuntimeError: If ``backbone`` is not one of the names above.
        """
        super().__init__()

        # Choose the factory without calling it, so the network is
        # constructed at exactly one call site below.
        if backbone == "resnet34":
            factory = resnet34_v1b
        elif backbone == "resnet50":
            factory = resnet50_v1s
        elif backbone == "resnet101":
            factory = resnet101_v1s
        elif backbone == "resnet152":
            factory = resnet152_v1s
        else:
            raise RuntimeError(f"unknown backbone: {backbone}")
        net = factory(pretrained=pretrained_base, dilated=dilated, **kwargs)

        # setattr goes through nn.Module.__setattr__, so Module values are
        # registered as submodules. The fixed order below keeps registration
        # order (and thus state_dict key order) stable.
        for name in (
            "conv1",
            "bn1",
            "relu",
            "maxpool",
            "layer1",
            "layer2",
            "layer3",
            "layer4",
        ):
            setattr(self, name, getattr(net, name))

    def forward(self, x, additional_features=None):
        """Run the stem and the four residual stages.

        Args:
            x: Input to ``conv1``. It is presumably an ``(N, C, H, W)``
                batch; TODO confirm with callers.
            additional_features: Optional tensor added element-wise to the
                stem activations after ``relu`` and before ``maxpool``.
                Before the addition it is zero-padded at the end of its
                third-from-last dimension by ``activations.size(1) -
                additional_features.size(1)`` entries. Dim -3 and dim 1
                coincide for a 4-D tensor. If that difference is negative,
                ``F.pad`` receives negative padding. NOTE(review): in
                constant mode this presumably trims trailing entries rather
                than failing; confirm for the torch version in use.

        Returns:
            Tuple ``(c1, c2, c3, c4)`` holding the outputs of ``layer1`` ..
            ``layer4``. Each stage consumes the previous stage's output.
        """
        out = self.relu(self.bn1(self.conv1(x)))

        if additional_features is not None:
            # F.pad pairs run from the last dimension backwards. The last two
            # dimensions are untouched; dim -3 gets only trailing padding.
            channel_gap = out.size(1) - additional_features.size(1)
            out = out + torch.nn.functional.pad(
                additional_features,
                [0, 0, 0, 0, 0, channel_gap],
                mode="constant",
                value=0,
            )

        out = self.maxpool(out)

        stage_outputs = []
        for stage in (self.layer1, self.layer2, self.layer3, self.layer4):
            out = stage(out)
            stage_outputs.append(out)
        return tuple(stage_outputs)