|
|
|
|
|
|
|
|
|
|
|
|
|
import numpy as np |
|
import torch |
|
import torch.nn as nn |
|
|
|
from collections import OrderedDict |
|
from networks.layers import upsample, ConvBlock, Conv3x3 |
|
|
|
from einops import rearrange |
|
|
|
|
|
class DepthDecoder(nn.Module):
    """Five-stage convolutional decoder producing per-scale depth/disparity maps.

    The decoder starts from the last entry of the encoder feature list and
    runs stages ``i = 4 .. 0``.  Each stage applies ``ConvBlock`` ("upconv",
    i, 0), the project ``upsample`` helper, an optional concatenation with
    the encoder feature ``input_features[i - 1]`` (skip connection), and a
    second ``ConvBlock`` ("upconv", i, 1).  For every stage index listed in
    ``cfg.model.scales`` a ``Conv3x3`` head ("dispconv", i) produces the raw
    prediction, which is passed through an activation chosen by
    ``cfg.model.depth_type``.

    Args:
        cfg: Configuration object. The fields read here are
            ``cfg.model.name``, ``cfg.model.gaussians_per_pixel``,
            ``cfg.model.scales``, ``cfg.model.depth_type``,
            ``cfg.model.depth_scale`` and ``cfg.model.depth_bias``.
        num_ch_enc: Indexable sequence of encoder channel counts; entry
            ``[-1]`` feeds the first stage and entries ``[i - 1]`` size the
            skip connections.
        num_output_channels: Channels per predicted layer; the head width is
            this value times the number of depth layers.
        use_skips: Whether encoder features are concatenated at stages
            ``i > 0``.

    Raises:
        ValueError: If ``cfg.model.depth_type`` is not one of ``"disp"``,
            ``"disp_inc"``, ``"depth"`` or ``"depth_inc"``.
    """

    # Values of cfg.model.depth_type understood by this decoder.
    _DISP_TYPES = ("disp", "disp_inc")
    _DEPTH_TYPES = ("depth", "depth_inc")

    def __init__(self, cfg, num_ch_enc, num_output_channels=1, use_skips=True):
        super(DepthDecoder, self).__init__()

        # Fail fast on an unknown depth type.  Without this check no
        # activation would be assigned and the problem would only surface
        # as an AttributeError on the first forward pass.
        depth_type = cfg.model.depth_type
        supported = self._DISP_TYPES + self._DEPTH_TYPES
        if depth_type not in supported:
            raise ValueError(
                f"Unsupported cfg.model.depth_type {depth_type!r}; "
                f"expected one of {supported}"
            )

        self.cfg = cfg
        depth_num = self._num_depth_layers(cfg)
        self.num_output_channels = num_output_channels * depth_num
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.scales = cfg.model.scales

        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # Layers are kept in an OrderedDict keyed by tuples for lookup in
        # forward(); the ModuleList built from it below is what registers
        # them as submodules.  Insertion order is preserved on purpose: it
        # fixes the ModuleList indices and the order random init is drawn.
        self.convs = OrderedDict()
        for i in range(4, -1, -1):
            # First conv of the stage: from the previous stage (or the
            # deepest encoder feature when i == 4) to this stage's width.
            stage_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            stage_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(stage_in, stage_out)

            # Second conv of the stage: input is widened by the skip
            # feature's channels when a skip connection is concatenated.
            fused_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                fused_in = fused_in + self.num_ch_enc[i - 1]
            self.convs[("upconv", i, 1)] = ConvBlock(fused_in, self.num_ch_dec[i])

        for s in self.scales:
            head = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
            self.convs[("dispconv", s)] = head
            # Second positional argument of xavier_uniform_ is the gain.
            nn.init.xavier_uniform_(head.conv.weight, cfg.model.depth_scale)
            nn.init.constant_(head.conv.bias, cfg.model.depth_bias)

        self.decoder = nn.ModuleList(list(self.convs.values()))

        if depth_type in self._DISP_TYPES:
            self.activate = nn.Sigmoid()
        elif depth_type == "depth":
            self.activate = nn.Softplus()
        else:  # "depth_inc" (guaranteed by the validation above)
            self.activate = torch.exp

    @staticmethod
    def _num_depth_layers(cfg):
        """Return how many layers each head predicts per pixel.

        This is ``cfg.model.gaussians_per_pixel``, reduced by one when the
        substring ``"unidepth"`` occurs in ``cfg.model.name``.
        """
        per_pixel = cfg.model.gaussians_per_pixel
        return per_pixel - 1 if "unidepth" in cfg.model.name else per_pixel

    def forward(self, input_features):
        """Decode encoder features into per-scale predictions.

        Args:
            input_features: Indexable sequence of encoder feature tensors;
                ``[-1]`` seeds the decoder and ``[i - 1]`` is concatenated
                along dim 1 at stage ``i`` when skips are enabled.
                NOTE(review): spatial sizes must match the upsampled decoder
                feature for the concatenation to work; confirm against the
                encoder and ``upsample`` in ``networks.layers``.

        Returns:
            dict: Maps ``("disp", i)`` (for ``"disp"``/``"disp_inc"``) or
            ``("depth", i)`` (for ``"depth"``/``"depth_inc"``) to the
            activated head output for each stage ``i`` in ``self.scales``.
            The channel axis of each output is split into the predicted
            layers, which are folded into the batch axis
            (``b (n c) ... -> (b n) c ...``).
        """
        outputs = {}

        # Both values are constant for the whole pass; read them once.
        depth_type = self.cfg.model.depth_type
        depth_num = self._num_depth_layers(self.cfg)

        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)
            parts = [upsample(x)]
            if self.use_skips and i > 0:
                parts.append(input_features[i - 1])
            x = torch.cat(parts, 1)
            x = self.convs[("upconv", i, 1)](x)

            if i not in self.scales:
                continue

            raw = self.convs[("dispconv", i)](x)
            if depth_type == "depth_inc":
                key = "depth"
                # The activation here is exp; clamping the logits bounds the
                # result to [e^-10, e^6] and keeps it finite.
                raw = torch.clamp(raw, min=-10.0, max=6.0)
            elif depth_type in self._DISP_TYPES:
                key = "disp"
            else:
                key = depth_type

            outputs[(key, i)] = rearrange(
                self.activate(raw),
                'b (n c) ...-> (b n) c ...',
                n=depth_num,
            )

        return outputs
|
|