# Copyright (c) Meta Platforms, Inc. and affiliates.

import torch.nn as nn
from torchvision.models.resnet import Bottleneck

from .base import BaseModel
from .feature_extractor import AdaptationBlock
from .utils import checkpointed


class BEVNet(BaseModel):
    default_conf = {
        "pretrained": True,
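        "num_blocks": "???",  # mandatory (OmegaConf "???" = missing value)
        "latent_dim": "???",  # mandatory
        "input_dim": "${.latent_dim}",  # defaults to latent_dim via interpolation
        "output_dim": "${.latent_dim}",  # defaults to latent_dim via interpolation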
        "confidence": False,
        "norm_layer": "nn.BatchNorm2d",  # normalization in decoder blocks
        "checkpointed": False,  # whether to use gradient checkpointing
        "padding": "zeros",
    }

    def _init(self, conf):
        blocks = []
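        # Optionally wrap Bottleneck with gradient checkpointing (conf.checkpointed).
        Block = checkpointed(Bottleneck, do=conf.checkpointed)
        for i in range(conf.num_blocks):
            # Only the first block consumes input_dim channels.
            dim = conf.input_dim if i == 0 else conf.latent_dim
            blocks.append(
                Block(
                    dim,
                    # A Bottleneck outputs planes * expansion (= 4) channels, so this
                    # keeps every block at latent_dim channels. No downsample branch
                    # is passed, so the residual sum requires dim == latent_dim and
                    # latent_dim divisible by 4.
                    conf.latent_dim // Bottleneck.expansion,
                    norm_layer=eval(conf.norm_layer),  # e.g. "nn.BatchNorm2d"
                )
            )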
        self.blocks = nn.Sequential(*blocks)
        self.output_layer = AdaptationBlock(conf.latent_dim, conf.output_dim)
        if conf.confidence:
            self.confidence_layer = AdaptationBlock(conf.latent_dim, 1)

        def update_padding(module):
            # Switch every Conv2d to the requested padding mode (e.g. "replicate").
            if isinstance(module, nn.Conv2d):
                module.padding_mode = conf.padding

        if conf.padding != "zeros":
            self.blocks.apply(update_padding)

    def _forward(self, data):
        features = self.blocks(data["input"])
        pred = {
            "output": self.output_layer(features),
        }
        if self.conf.confidence:
            # Single-channel map (B, 1, H, W) -> (B, H, W), squashed to (0, 1).
            pred["confidence"] = self.confidence_layer(features).squeeze(1).sigmoid()
        return pred

    def loss(self, pred, data):
        raise NotImplementedError

    def metrics(self, pred, data):
        raise NotImplementedError
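

# ---------------------------------------------------------------------------
# Illustrative sketch, not part of the original BEVNet API. The helpers below
# (_sketch_*) are hypothetical names and use only plain torch/torchvision
# modules, so the mechanisms configured in BEVNet can be checked without
# BaseModel, AdaptationBlock or the `checkpointed` wrapper:
#   1. a stack of torchvision Bottlenecks with planes = latent_dim // expansion
#      keeps the width at latent_dim, so the identity shortcut works without a
#      downsample branch;
#   2. setting `padding_mode` on already-built Conv2d layers (as
#      `update_padding` does) changes how borders are padded at forward time;
#   3. the confidence head turns (B, 1, H, W) logits into (B, H, W) values.
# ---------------------------------------------------------------------------
import torch
import torch.nn as nn
from torchvision.models.resnet import Bottleneck


def _sketch_bottleneck_stack(latent_dim: int, num_blocks: int) -> nn.Sequential:
    """Hypothetical helper: BEVNet-style block stack, without checkpointing."""
    if latent_dim % Bottleneck.expansion != 0:
        raise ValueError("latent_dim must be divisible by Bottleneck.expansion")
    return nn.Sequential(
        *[
            Bottleneck(
                latent_dim,
                latent_dim // Bottleneck.expansion,  # output = planes * 4
                norm_layer=nn.BatchNorm2d,
            )
            for _ in range(num_blocks)
        ]
    )


def _sketch_set_padding(module: nn.Module, padding: str) -> nn.Module:
    """Hypothetical helper mirroring `update_padding` in BEVNet._init."""

    def update_padding(m: nn.Module) -> None:
        if isinstance(m, nn.Conv2d):
            m.padding_mode = padding

    # nn.Module.apply visits every submodule (and the module itself) and
    # returns the module.
    return module.apply(update_padding)


def _sketch_confidence(logits: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: (B, 1, H, W) logits -> (B, H, W) in (0, 1)."""
    return logits.squeeze(1).sigmoid()


# Example (shapes only):
#   stack = _sketch_bottleneck_stack(latent_dim=64, num_blocks=2)
#   stack(torch.randn(2, 64, 8, 8)).shape  # torch.Size([2, 64, 8, 8])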