Flash3d / flash3d /networks /depth_decoder.py
Ryukijano's picture
commit the whole flash3d
ffbcf9e verified
# Copyright Niantic 2019. Patent Pending. All rights reserved.
#
# This software is licensed under the terms of the Monodepth2 licence
# which allows for non-commercial use only, the full terms of which are made
# available in the LICENSE file.
import numpy as np
import torch
import torch.nn as nn
from collections import OrderedDict
from networks.layers import upsample, ConvBlock, Conv3x3
from einops import rearrange
class DepthDecoder(nn.Module):
    """Five-stage up-convolutional decoder producing depth/disparity maps.

    Starting from the last encoder feature map, each of the five stages
    (indexed 4 down to 0) applies a ``ConvBlock``, calls ``upsample``,
    optionally concatenates the encoder feature ``input_features[i - 1]``
    along dim 1 (skip connection), and applies a second ``ConvBlock``.
    For every stage index listed in ``cfg.model.scales`` a ``Conv3x3`` head
    followed by an activation emits the prediction for that scale.

    Configuration values read from ``cfg.model``:
        gaussians_per_pixel, name: determine how many maps are predicted
            per pixel (one fewer when ``"unidepth"`` occurs in ``name``).
        scales: stage indices at which an output head is attached.
        depth_scale: passed as the ``gain`` of ``xavier_uniform_`` for the
            head weights.
        depth_bias: constant used to initialise the head biases.
        depth_type: one of ``"disp"``, ``"disp_inc"``, ``"depth"``,
            ``"depth_inc"``; selects the output activation and output key.
    """

    # depth_type values whose predictions are stored under the "disp" key.
    _DISP_TYPES = ("disp", "disp_inc")

    def __init__(self, cfg, num_ch_enc, num_output_channels=1, use_skips=True):
        """Build the decoder layers.

        Args:
            cfg: configuration object exposing the ``cfg.model`` fields
                listed in the class docstring.
            num_ch_enc: indexable sequence of encoder channel counts;
                ``num_ch_enc[-1]`` feeds the first stage and
                ``num_ch_enc[i - 1]`` sizes the skip input of stage ``i``.
            num_output_channels: channels per predicted map; the head
                emits this many channels times the number of maps.
            use_skips: whether encoder features are concatenated in
                stages 4..1.

        Raises:
            ValueError: if ``cfg.model.depth_type`` is not supported.
        """
        super().__init__()
        self.cfg = cfg
        self.num_output_channels = num_output_channels * self._depth_num(cfg)
        self.use_skips = use_skips
        # Stored for reference; not read anywhere inside this class.
        self.upsample_mode = 'nearest'
        self.scales = cfg.model.scales
        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array([16, 32, 64, 128, 256])

        # decoder
        # NOTE: insertion order matters - it fixes the indices inside
        # ``self.decoder`` and therefore the parameter names of the module.
        self.convs = OrderedDict()
        for i in range(4, -1, -1):
            num_ch_out = self.num_ch_dec[i]

            # upconv_0: from the previous stage (or the deepest encoder
            # feature for the first stage) to this stage's width.
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)

            # upconv_1: consumes the upsampled tensor plus the optional skip.
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in = num_ch_in + self.num_ch_enc[i - 1]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        for s in self.scales:
            out = Conv3x3(self.num_ch_dec[s], self.num_output_channels)
            self.convs[("dispconv", s)] = out
            # Second positional argument of xavier_uniform_ is the gain.
            nn.init.xavier_uniform_(out.conv.weight, cfg.model.depth_scale)
            nn.init.constant_(out.conv.bias, cfg.model.depth_bias)

        # ``self.convs`` is a plain dict, so the layers are additionally
        # wrapped in a ModuleList to register them with this module.
        self.decoder = nn.ModuleList(list(self.convs.values()))

        # Fail at construction time on an unsupported depth_type instead of
        # leaving ``self.activate`` undefined until ``forward`` is called.
        self.activate = self._make_activation(cfg.model.depth_type)

    @staticmethod
    def _depth_num(cfg):
        """Return the number of maps predicted per pixel for ``cfg``."""
        per_pixel = cfg.model.gaussians_per_pixel
        return per_pixel - 1 if "unidepth" in cfg.model.name else per_pixel

    @staticmethod
    def _make_activation(depth_type):
        """Return the output activation callable for ``depth_type``.

        Raises:
            ValueError: if ``depth_type`` is not one of the supported values.
        """
        if depth_type in ("disp", "disp_inc"):
            return nn.Sigmoid()
        if depth_type == "depth":
            return nn.Softplus()
        if depth_type == "depth_inc":
            return torch.exp
        raise ValueError(
            "Unsupported cfg.model.depth_type {!r}; expected one of "
            "'disp', 'disp_inc', 'depth', 'depth_inc'".format(depth_type)
        )

    def forward(self, input_features):
        """Decode encoder features into per-scale predictions.

        Args:
            input_features: indexable sequence of encoder feature tensors;
                ``input_features[-1]`` is the decoder input and
                ``input_features[i - 1]`` is the skip for stage ``i``.
                (Presumably ordered shallow-to-deep - confirm with caller.)

        Returns:
            dict mapping ``(name, scale)`` to a tensor, where ``name`` is
            ``"disp"`` for the ``disp``/``disp_inc`` types and ``"depth"``
            for ``depth``/``depth_inc``. The head output's channel axis is
            split into the per-pixel map count and that factor is folded
            into the batch axis (``'b (n c) ... -> (b n) c ...'``).
        """
        outputs = {}

        # Loop-invariant configuration, resolved once per call.
        depth_type = self.cfg.model.depth_type
        depth_num = self._depth_num(self.cfg)
        if depth_type in self._DISP_TYPES:
            out_name = "disp"
        elif depth_type == "depth_inc":
            out_name = "depth"
        else:
            out_name = depth_type

        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)
            x = [upsample(x)]
            if self.use_skips and i > 0:
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)
            x = self.convs[("upconv", i, 1)](x)

            if i in self.scales:
                raw = self.convs[("dispconv", i)](x)
                if depth_type == "depth_inc":
                    # The activation is exp here; clamping its argument
                    # bounds the output to [exp(-10), exp(6)].
                    raw = torch.clamp(raw, min=-10.0, max=6.0)
                outputs[(out_name, i)] = rearrange(
                    self.activate(raw),
                    'b (n c) ...-> (b n) c ...', n=depth_num)
        return outputs