File size: 3,587 Bytes
ffbcf9e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.

import numpy as np
import torch
import torch.nn as nn

from collections import OrderedDict
from networks.layers import upsample, ConvBlock, Conv3x3

from einops import rearrange


class DepthDecoder(nn.Module):
    """Five-stage upsampling decoder producing per-scale depth/disparity maps.

    Each stage ``i`` (4 down to 0) applies ``ConvBlock`` -> ``upsample`` ->
    optional skip concatenation with ``input_features[i - 1]`` -> ``ConvBlock``.
    For every stage listed in ``cfg.model.scales`` a ``Conv3x3`` head followed
    by an activation emits the prediction for that scale.

    Config fields read from ``cfg.model``:
        gaussians_per_pixel, name: determine how many maps are predicted per
            pixel (one fewer when ``"unidepth"`` occurs in ``name``).
        scales: stage indices that get an output head.
        depth_scale, depth_bias: Xavier gain and constant bias used to
            initialise every output head.
        depth_type: one of ``"disp"``, ``"disp_inc"``, ``"depth"``,
            ``"depth_inc"``; selects the activation and the output key.

    Args:
        cfg: configuration object exposing the ``cfg.model`` fields above.
        num_ch_enc: per-level encoder channel counts; indexed at ``[-1]`` and
            ``[i - 1]`` for ``i`` in 1..4.
        num_output_channels: channels per predicted map (multiplied by the
            number of maps per pixel to size the output heads).
        use_skips: whether to concatenate encoder features as skip connections.

    Raises:
        ValueError: if ``cfg.model.depth_type`` is not a supported value.
    """

    _SUPPORTED_DEPTH_TYPES = ("disp", "disp_inc", "depth", "depth_inc")

    def __init__(self, cfg, num_ch_enc, num_output_channels=1, use_skips=True):
        super().__init__()

        # Fail fast: without this check an unsupported depth_type would build
        # the whole network but leave ``self.activate`` undefined, surfacing
        # only later as an AttributeError inside ``forward``.
        depth_type = cfg.model.depth_type
        if depth_type not in self._SUPPORTED_DEPTH_TYPES:
            raise ValueError(
                f"Unsupported cfg.model.depth_type {depth_type!r}; "
                f"expected one of {self._SUPPORTED_DEPTH_TYPES}"
            )

        self.cfg = cfg
        depth_num = self._num_depth_maps(cfg)
        self.num_output_channels = num_output_channels * depth_num
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = cfg.model.scales

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # decoder
        # NOTE: layers are created in a fixed order (stage 4 -> 0, then the
        # output heads) because ``self.decoder`` below registers them by
        # position; reordering would change the parameter names.
        self.convs = OrderedDict()
        for i in range(4, -1, -1):
            # upconv_0: first stage consumes the deepest encoder features,
            # later stages consume the previous decoder stage.
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1: input grows by the skip-connection channels, if any.
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        for s in self.scales:
            out = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
            self.convs[("dispconv", s)] = out
            nn.init.xavier_uniform_(out.conv.weight, gain=cfg.model.depth_scale)
            nn.init.constant_(out.conv.bias, cfg.model.depth_bias)

        # ``self.convs`` is a plain dict, so the layers are registered with
        # the module through this ModuleList.
        self.decoder = nn.ModuleList(list(self.convs.values()))
        if depth_type in ("disp", "disp_inc"):
            self.activate = nn.Sigmoid()
        elif depth_type == "depth":
            self.activate = nn.Softplus()
        else:  # "depth_inc" (validated above)
            self.activate = torch.exp

    @staticmethod
    def _num_depth_maps(cfg):
        """Return how many depth maps are predicted per pixel for ``cfg``."""
        per_pixel = cfg.model.gaussians_per_pixel
        return per_pixel - 1 if "unidepth" in cfg.model.name else per_pixel

    def forward(self, input_features):
        """Decode encoder features into per-scale predictions.

        Args:
            input_features: sequence of encoder feature maps; ``[-1]`` seeds
                the decoder and ``[i - 1]`` is used as the skip connection of
                stage ``i`` (presumably ordered shallow -> deep; confirm
                against the encoder).

        Returns:
            dict mapping ``(name, scale)`` to a tensor, where ``name`` is
            ``"disp"`` for the ``disp``/``disp_inc`` types and ``"depth"`` for
            ``depth``/``depth_inc``. The per-pixel maps are split out of the
            channel axis and folded into the batch axis.
        """
        depth_type = self.cfg.model.depth_type
        depth_num = self._num_depth_maps(self.cfg)
        if depth_type in ("disp", "disp_inc"):
            out_name = "disp"
        elif depth_type == "depth_inc":
            out_name = "depth"
        else:
            out_name = depth_type

        outputs = {}
        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)
            x = [upsample(x)]
            if self.use_skips and i > 0:
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)
            x = self.convs[("upconv", i, 1)](x)
            if i not in self.scales:
                continue

            raw = self.convs[("dispconv", i)](x)
            if depth_type == "depth_inc":
                # Bound the logits so exp() stays within [e^-10, e^6].
                raw = torch.clamp(raw, min=-10.0, max=6.0)
            outputs[(out_name, i)] = rearrange(
                self.activate(raw), 'b (n c) ...-> (b n) c ...', n=depth_num
            )
        return outputs