import torch

from .resnetv1b import resnet34_v1b, resnet50_v1s, resnet101_v1s, resnet152_v1s


class ResNetBackbone(torch.nn.Module):
    """ResNet feature extractor that returns the outputs of layer1-layer4."""

    def __init__(
        self, backbone="resnet50", pretrained_base=True, dilated=True, **kwargs
    ):
        super(ResNetBackbone, self).__init__()

        if backbone == "resnet34":
            pretrained = resnet34_v1b(
                pretrained=pretrained_base, dilated=dilated, **kwargs
            )
        elif backbone == "resnet50":
            pretrained = resnet50_v1s(
                pretrained=pretrained_base, dilated=dilated, **kwargs
            )
        elif backbone == "resnet101":
            pretrained = resnet101_v1s(
                pretrained=pretrained_base, dilated=dilated, **kwargs
            )
        elif backbone == "resnet152":
            pretrained = resnet152_v1s(
                pretrained=pretrained_base, dilated=dilated, **kwargs
            )
        else:
            raise RuntimeError(f"unknown backbone: {backbone}")

        # Reuse the stem and the four residual stages of the pretrained network.
        self.conv1 = pretrained.conv1
        self.bn1 = pretrained.bn1
        self.relu = pretrained.relu
        self.maxpool = pretrained.maxpool
        self.layer1 = pretrained.layer1
        self.layer2 = pretrained.layer2
        self.layer3 = pretrained.layer3
        self.layer4 = pretrained.layer4

    def forward(self, x, additional_features=None):
        x = self.conv1(x)
        x = self.bn1(x)
        x = self.relu(x)
        if additional_features is not None:
            # Zero-pad additional_features along the channel dimension to match
            # the stem output, then fuse the two by element-wise addition.
            x = x + torch.nn.functional.pad(
                additional_features,
                [0, 0, 0, 0, 0, x.size(1) - additional_features.size(1)],
                mode="constant",
                value=0,
            )
        x = self.maxpool(x)
        c1 = self.layer1(x)
        c2 = self.layer2(c1)
        c3 = self.layer3(c2)
        c4 = self.layer4(c3)

        return c1, c2, c3, c4
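

if __name__ == "__main__":
    # Illustrative usage sketch, not part of the original module. It assumes
    # the file is executed as part of its package (e.g. `python -m <package>.resnet`)
    # so the relative import of resnetv1b resolves. With the gluon-style v1b/v1s
    # variants and dilated=True, c1..c4 have spatial strides 4/8/8/8; for
    # resnet50 their channel widths are 256/512/1024/2048.
    net = ResNetBackbone(backbone="resnet50", pretrained_base=False)
    with torch.no_grad():
        c1, c2, c3, c4 = net(torch.zeros(1, 3, 224, 224))
    print([tuple(t.shape) for t in (c1, c2, c3, c4)])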