# flash3d/networks/gaussian_decoder.py
from collections import OrderedDict

import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np


def upsample(x):
    """Upsample input tensor by a factor of 2."""
    return F.interpolate(x, scale_factor=2, mode="nearest")


class Conv3x3(nn.Module):
    """Layer to pad and convolve input, preserving spatial resolution."""
    def __init__(self, in_channels, out_channels, use_refl=True):
        super(Conv3x3, self).__init__()
        if use_refl:
            self.pad = nn.ReflectionPad2d(1)
        else:
            self.pad = nn.ZeroPad2d(1)
        self.conv = nn.Conv2d(int(in_channels), int(out_channels), 3)

    def forward(self, x):
        out = self.pad(x)
        out = self.conv(out)
        return out


class ConvBlock(nn.Module):
    """Layer to perform a convolution followed by ELU."""
    def __init__(self, in_channels, out_channels):
        super(ConvBlock, self).__init__()
        self.conv = Conv3x3(in_channels, out_channels)
        self.nonlin = nn.ELU(inplace=True)

    def forward(self, x):
        out = self.conv(x)
        out = self.nonlin(out)
        return out


def get_splits_and_inits(cfg):
    """Return the per-parameter channel split sizes of the output head and,
    for each split, the initialisation gain and bias of its weights.
    """
    split_dimensions = []
    scale_inits = []
    bias_inits = []
    for g_idx in range(cfg.model.gaussians_per_pixel):
        if cfg.model.predict_offset:
            # 3 channels for the xyz offset of each Gaussian
            split_dimensions += [3]
            scale_inits += [cfg.model.xyz_scale]
            bias_inits += [cfg.model.xyz_bias]
        # opacity (1), scaling (3), rotation quaternion (4), SH DC colour (3)
        split_dimensions += [1, 3, 4, 3]
        scale_inits += [cfg.model.opacity_scale,
                        cfg.model.scale_scale,
                        1.0,
                        5.0]
        bias_inits += [cfg.model.opacity_bias,
                       np.log(cfg.model.scale_bias),
                       0.0,
                       0.0]
        if cfg.model.max_sh_degree != 0:
            # higher-order spherical-harmonic coefficients, 3 colours per term
            sh_num = (cfg.model.max_sh_degree + 1) ** 2 - 1
            sh_num_rgb = sh_num * 3
            split_dimensions.append(sh_num_rgb)
            scale_inits.append(cfg.model.sh_scale)
            bias_inits.append(0.0)
        if not cfg.model.one_gauss_decoder:
            break
    return split_dimensions, scale_inits, bias_inits
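
# Worked example (illustrative values, not from a project config): with
# predict_offset=True, max_sh_degree=1 and gaussians_per_pixel=1, we get
# sh_num = (1 + 1) ** 2 - 1 = 3 and sh_num_rgb = 9, so the splits are
#   [3, 1, 3, 4, 3, 9]   # offset, opacity, scaling, rotation, feat_dc, feat_rest
# for 23 output channels in total.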


class GaussianDecoder(nn.Module):
    """Decode multi-scale encoder features into per-pixel Gaussian parameters."""
    def __init__(self, cfg, num_ch_enc, use_skips=True):
        super(GaussianDecoder, self).__init__()
        self.cfg = cfg
        self.use_skips = use_skips
        self.upsample_mode = 'nearest'
        self.num_ch_enc = num_ch_enc
        self.num_ch_dec = np.array(cfg.model.num_ch_dec)

        split_dimensions, scale, bias = get_splits_and_inits(cfg)
        # [offset], opacity, scaling, rotation, feat_dc, [features_rest]
        assert not cfg.model.unified_decoder

        self.split_dimensions = split_dimensions
        self.num_output_channels = sum(self.split_dimensions)

        # decoder
        self.convs = OrderedDict()
        for i in range(4, -1, -1):
            # upconv_0
            num_ch_in = self.num_ch_enc[-1] if i == 4 else self.num_ch_dec[i + 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 0)] = ConvBlock(num_ch_in, num_ch_out)
            # upconv_1
            num_ch_in = self.num_ch_dec[i]
            if self.use_skips and i > 0:
                num_ch_in += self.num_ch_enc[i - 1]
            num_ch_out = self.num_ch_dec[i]
            self.convs[("upconv", i, 1)] = ConvBlock(num_ch_in, num_ch_out)

        self.out = nn.Conv2d(self.num_ch_dec[0], self.num_output_channels, 1)

        # initialise each slice of the 1x1 output head with its own gain and bias
        out_channels = self.split_dimensions
        start_channels = 0
        for out_channel, b, s in zip(out_channels, bias, scale):
            nn.init.xavier_uniform_(
                self.out.weight[start_channels:start_channels + out_channel,
                                :, :, :], s)
            nn.init.constant_(
                self.out.bias[start_channels:start_channels + out_channel], b)
            start_channels += out_channel

        # register the decoder convolutions so their parameters are tracked
        self.decoder = nn.ModuleList(list(self.convs.values()))

        self.scaling_activation = torch.exp
        self.opacity_activation = torch.sigmoid
        self.rotation_activation = torch.nn.functional.normalize
        self.scaling_lambda = cfg.model.scale_lambda
        self.sigmoid = nn.Sigmoid()
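
    # forward() maps the raw output slices to valid Gaussian parameters:
    # sigmoid -> opacity in (0, 1), exp scaled by scale_lambda -> positive
    # scales, and L2 normalisation -> unit quaternions for rotation.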
    def forward(self, input_features):
        # decoder: start from the deepest feature map and upsample back up,
        # concatenating skip connections from the encoder along the way
        x = input_features[-1]
        for i in range(4, -1, -1):
            x = self.convs[("upconv", i, 0)](x)
            x = [upsample(x)]
            if self.use_skips and i > 0:
                x += [input_features[i - 1]]
            x = torch.cat(x, 1)
            x = self.convs[("upconv", i, 1)](x)
        x = self.out(x)

        split_network_outputs = x.split(self.split_dimensions, dim=1)

        offset_list = []
        opacity_list = []
        scaling_list = []
        rotation_list = []
        feat_dc_list = []
        feat_rest_list = []

        assert not self.cfg.model.unified_decoder

        for i in range(self.cfg.model.gaussians_per_pixel):
            assert self.cfg.model.max_sh_degree > 0
            if self.cfg.model.predict_offset:
                offset_s, opacity_s, scaling_s, \
                    rotation_s, feat_dc_s, features_rest_s = split_network_outputs[i*6:(i+1)*6]
                offset_list.append(offset_s[:, None, ...])
            else:
                opacity_s, scaling_s, rotation_s, \
                    feat_dc_s, features_rest_s = split_network_outputs[i*5:(i+1)*5]
            opacity_list.append(opacity_s[:, None, ...])
            scaling_list.append(scaling_s[:, None, ...])
            rotation_list.append(rotation_s[:, None, ...])
            feat_dc_list.append(feat_dc_s[:, None, ...])
            feat_rest_list.append(features_rest_s[:, None, ...])
            if not self.cfg.model.one_gauss_decoder:
                break

        # squeezing removes the extra dimension if there is only one Gaussian per pixel
        opacity = torch.cat(opacity_list, dim=1).squeeze(1)
        scaling = torch.cat(scaling_list, dim=1).squeeze(1)
        rotation = torch.cat(rotation_list, dim=1).squeeze(1)
        feat_dc = torch.cat(feat_dc_list, dim=1).squeeze(1)
        features_rest = torch.cat(feat_rest_list, dim=1).squeeze(1)

        out = {
            ("gauss_opacity", 0): self.opacity_activation(opacity),
            ("gauss_scaling", 0): self.scaling_activation(scaling) * self.scaling_lambda,
            ("gauss_rotation", 0): self.rotation_activation(rotation),
            ("gauss_features_dc", 0): feat_dc,
            ("gauss_features_rest", 0): features_rest
        }

        if self.cfg.model.predict_offset:
            offset = torch.cat(offset_list, dim=1).squeeze(1)
            out[("gauss_offset", 0)] = offset
        return out
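

if __name__ == "__main__":
    # Minimal smoke test with a mocked config -- a sketch only. Every value
    # below (channel widths, scales, biases, image size) is an illustrative
    # assumption; the real settings come from the project's cfg.model.* options.
    from types import SimpleNamespace

    cfg = SimpleNamespace(model=SimpleNamespace(
        gaussians_per_pixel=1,
        predict_offset=True,
        xyz_scale=1.0, xyz_bias=0.0,
        opacity_scale=1.0, opacity_bias=0.0,
        scale_scale=1.0, scale_bias=1.0,
        max_sh_degree=1, sh_scale=1.0,
        one_gauss_decoder=False,
        unified_decoder=False,
        scale_lambda=1.0,
        num_ch_dec=[16, 32, 64, 128, 256],
    ))
    num_ch_enc = [64, 64, 128, 256, 512]  # e.g. ResNet-18 style feature widths
    decoder = GaussianDecoder(cfg, num_ch_enc)

    # fake encoder features at strides 2, 4, 8, 16 and 32 for a 64x64 input
    feats = [torch.randn(1, c, 64 // 2 ** (i + 1), 64 // 2 ** (i + 1))
             for i, c in enumerate(num_ch_enc)]
    for k, v in decoder(feats).items():
        print(k, tuple(v.shape))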