qitaoz's picture
init commit
4545610 verified
raw
history blame
9.51 kB
#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#
from typing import NamedTuple
import torch.nn as nn
import torch
from . import _C
def cpu_deep_copy_tuple(input_tuple):
    """Return a snapshot of *input_tuple* with every tensor cloned onto the CPU.

    Each element that is a ``torch.Tensor`` is replaced by ``item.cpu().clone()``;
    every other element is passed through as the same object. The result is
    always a new ``tuple`` with the elements in their original order.
    """

    def _snapshot(entry):
        # Only tensors are copied; scalars, flags, etc. are kept as-is.
        if isinstance(entry, torch.Tensor):
            return entry.cpu().clone()
        return entry

    return tuple(map(_snapshot, input_tuple))
def rasterize_gaussians(
    means3D,
    means2D,
    sh,
    colors_precomp,
    opacities,
    scales,
    rotations,
    cov3Ds_precomp,
    viewmat,
    raster_settings,
):
    """Run the differentiable rasterizer through ``_RasterizeGaussians``.

    This is a functional entry point: it forwards all ten arguments, in the
    same positional order, to ``_RasterizeGaussians.apply`` and returns
    whatever that call returns.
    """
    # The order below must match the signature of _RasterizeGaussians.forward.
    forwarded = (
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        viewmat,
        raster_settings,
    )
    return _RasterizeGaussians.apply(*forwarded)
class _RasterizeGaussians(torch.autograd.Function):
    """Autograd bridge to the compiled C++/CUDA Gaussian rasterizer (``_C``).

    ``forward`` calls ``_C.rasterize_gaussians`` and stashes what ``backward``
    needs; ``backward`` calls ``_C.rasterize_gaussians_backward`` and, when
    requested, additionally derives a gradient for the ``viewmat`` input in
    Python.
    """

    @staticmethod
    def forward(
        ctx,
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        viewmat,
        raster_settings,
    ):
        """Rasterize and return ``(color, radii, depth, alpha)``.

        ``means2D`` and ``viewmat`` are not passed to the compiled routine
        (the view matrix it uses comes from ``raster_settings.viewmatrix``);
        they are inputs only so that ``backward`` can hand gradients back for
        them. When ``raster_settings.debug`` is true, a CPU copy of the
        arguments is written to ``snapshot_fw.dump`` if the compiled call
        raises, and the exception is re-raised.
        """
        # Restructure arguments the way that the C++ lib expects them
        args = (
            raster_settings.bg,
            means3D,
            colors_precomp,
            opacities,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            raster_settings.image_height,
            raster_settings.image_width,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            raster_settings.prefiltered,
            raster_settings.debug,
        )

        # Invoke C++/CUDA rasterizer
        if raster_settings.debug:
            cpu_args = cpu_deep_copy_tuple(args)  # Copy them before they can be corrupted
            try:
                num_rendered, color, depth, alpha, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)
            except Exception:
                torch.save(cpu_args, "snapshot_fw.dump")
                print("\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.")
                raise
        else:
            num_rendered, color, depth, alpha, radii, geomBuffer, binningBuffer, imgBuffer = _C.rasterize_gaussians(*args)

        # Keep relevant tensors for backward
        ctx.raster_settings = raster_settings
        ctx.num_rendered = num_rendered
        ctx.save_for_backward(
            colors_precomp,
            means3D,
            scales,
            rotations,
            cov3Ds_precomp,
            radii,
            sh,
            geomBuffer,
            binningBuffer,
            imgBuffer,
            alpha,
        )
        return color, radii, depth, alpha

    @staticmethod
    def backward(ctx, grad_color, grad_radii, grad_depth, grad_alpha):
        """Return one gradient per ``forward`` input (ten entries).

        ``grad_radii`` is ignored. The ``raster_settings`` slot always gets
        ``None``. The ``viewmat`` slot gets a gradient only when autograd
        reports that the input needs one; otherwise it is ``None``.
        """
        # Restore necessary values from context
        num_rendered = ctx.num_rendered
        raster_settings = ctx.raster_settings
        (
            colors_precomp,
            means3D,
            scales,
            rotations,
            cov3Ds_precomp,
            radii,
            sh,
            geomBuffer,
            binningBuffer,
            imgBuffer,
            alpha,
        ) = ctx.saved_tensors

        # Restructure args as C++ method expects them
        args = (
            raster_settings.bg,
            means3D,
            radii,
            colors_precomp,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            grad_color,
            grad_depth,
            grad_alpha,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            geomBuffer,
            num_rendered,
            binningBuffer,
            imgBuffer,
            alpha,
            raster_settings.debug,
        )

        # Compute gradients for relevant tensors by invoking backward method
        if raster_settings.debug:
            cpu_args = cpu_deep_copy_tuple(args)  # Copy them before they can be corrupted
            try:
                grad_means2D, grad_ts, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)
            except Exception:
                torch.save(cpu_args, "snapshot_bw.dump")
                print("\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n")
                raise
        else:
            grad_means2D, grad_ts, grad_colors_precomp, grad_opacities, grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales, grad_rotations = _C.rasterize_gaussians_backward(*args)

        # `viewmat` is forward input #8 (0-based, not counting ctx). Autograd
        # rejects a non-None gradient for an input that was not a tensor
        # (e.g. viewmat=None), so only produce one when it is actually needed.
        grad_viewmat = None
        if ctx.needs_input_grad[8]:
            grad_viewmat = _RasterizeGaussians._viewmat_grad(
                raster_settings.projmatrix,
                raster_settings.perspectivematrix,
                means3D,
                grad_means2D,
                grad_ts,
            )

        grads = (
            grad_means3D,
            grad_means2D,
            grad_sh,
            grad_colors_precomp,
            grad_opacities,
            grad_scales,
            grad_rotations,
            grad_cov3Ds_precomp,
            grad_viewmat,
            None,
        )
        return grads

    @staticmethod
    def _viewmat_grad(projmatrix, perspectivematrix, means3D, grad_means2D, grad_ts):
        """Accumulate the gradient w.r.t. the (transposed) view matrix.

        Args:
            projmatrix: ``raster_settings.projmatrix``; it is transposed here
                and applied to the homogeneous means.
            perspectivematrix: ``raster_settings.perspectivematrix``; it is
                flattened and entries 0-11 are read.
            means3D: per-Gaussian 3D means (last dim 3, a ones column is
                appended to make them homogeneous).
            grad_means2D: gradient w.r.t. the 2D means as returned by the
                compiled backward pass; only columns 0 and 1 are read.
            grad_ts: extra per-Gaussian gradient returned by the compiled
                backward pass, added onto the first three components.

        Returns:
            The summed outer product of per-Gaussian gradients and
            homogeneous means, transposed.
        """
        with torch.no_grad():
            projmat = projmatrix.T
            means_h = torch.cat([means3D, torch.ones_like(means3D[..., :1])], dim=-1)
            p_hom = torch.einsum("ij,nj->ni", projmat, means_h)
            rw = 1 / (p_hom[..., 3] + 1e-5)
            rw_sq = torch.square(rw)

            proj = perspectivematrix.flatten()

            # v_t is the grad w.r.t. the 3D mean in camera coordinates
            # Math reference (https://arxiv.org/pdf/2312.02121.pdf)
            # One source is from grad_means2D (the grad w.r.t. the mean in ND coordinates, t' in the paper)
            v_tx = grad_means2D[:, 0] * (proj[0] * rw - proj[3] * p_hom[:, 0] * rw_sq)
            v_tx += grad_means2D[:, 1] * (proj[1] * rw - proj[3] * p_hom[:, 1] * rw_sq)
            v_ty = grad_means2D[:, 0] * (proj[4] * rw - proj[7] * p_hom[:, 0] * rw_sq)
            v_ty += grad_means2D[:, 1] * (proj[5] * rw - proj[7] * p_hom[:, 1] * rw_sq)
            v_tz = grad_means2D[:, 0] * (proj[8] * rw - proj[11] * p_hom[:, 0] * rw_sq)
            v_tz += grad_means2D[:, 1] * (proj[9] * rw - proj[11] * p_hom[:, 1] * rw_sq)
            v_t = torch.stack(
                [
                    v_tx,
                    v_ty,
                    v_tz,
                    torch.zeros_like(v_tx),
                ],
                dim=-1,
            )

            # Another source of gradients (grad_ts)
            # t is involved in the affine transform J when computing the 2D covariance matrix
            # grad_ts is gathered from "cuda_rasterizer/backward.cu"
            v_t[:, :3] += grad_ts

            # Finally, we compute the grad w.r.t. the viewmatrix from v_t
            return torch.einsum("ni,nj->ij", v_t, means_h).T  # We transposed the viewmatrix
class GaussianRasterizationSettings(NamedTuple):
    """Immutable bundle of per-render settings read by the rasterizer.

    Every field except ``perspectivematrix`` is forwarded to the compiled
    ``_C`` routines; ``perspectivematrix`` is read only by the Python code
    that derives the view-matrix gradient in the backward pass.
    """

    # Output image size, passed to the compiled forward call.
    image_height: int
    image_width: int
    # Passed to both the compiled forward and backward calls.
    tanfovx: float
    tanfovy: float
    bg: torch.Tensor
    scale_modifier: float
    viewmatrix: torch.Tensor
    # Flattened and indexed (entries 0-11) when computing the view-matrix
    # gradient; presumably a 4x4 matrix -- TODO confirm against the caller.
    perspectivematrix: torch.Tensor
    # Also transposed and applied to homogeneous 3D means in the backward pass.
    projmatrix: torch.Tensor
    sh_degree: int
    campos: torch.Tensor
    # Passed to the compiled forward call only.
    prefiltered: bool
    # When true, a CPU snapshot of the arguments is dumped to disk if a
    # compiled call raises.
    debug: bool
class GaussianRasterizer(nn.Module):
    """``nn.Module`` front end for the Gaussian rasterizer.

    Holds a settings object and exposes visibility marking plus the
    differentiable rasterization call.
    """

    def __init__(self, raster_settings):
        """Store *raster_settings* for use by later calls."""
        super().__init__()
        self.raster_settings = raster_settings

    def markVisible(self, positions):
        """Return the result of ``_C.mark_visible`` for *positions*.

        The call uses the stored view and projection matrices and runs with
        gradient tracking disabled.
        """
        # Mark visible points (based on frustum culling for camera) with a boolean
        settings = self.raster_settings
        with torch.no_grad():
            return _C.mark_visible(
                positions,
                settings.viewmatrix,
                settings.projmatrix,
            )

    def forward(
        self,
        means3D,
        means2D,
        opacities,
        shs=None,
        colors_precomp=None,
        scales=None,
        rotations=None,
        cov3D_precomp=None,
        viewmat=None,
    ):
        """Validate the optional inputs and invoke ``rasterize_gaussians``.

        Exactly one of ``shs`` / ``colors_precomp`` must be given, and either
        both ``scales`` and ``rotations`` or else ``cov3D_precomp`` alone.
        Omitted optional tensors are replaced by empty tensors before the
        call.

        Raises:
            Exception: if either of the two "exactly one of" rules is violated.
        """
        settings = self.raster_settings

        # Colour source: SHs xor precomputed colours.
        has_shs = shs is not None
        has_colors = colors_precomp is not None
        if has_shs == has_colors:
            raise Exception('Please provide excatly one of either SHs or precomputed colors!')

        # Covariance source: (scales AND rotations) xor precomputed covariance.
        if cov3D_precomp is None:
            bad_covariance = scales is None or rotations is None
        else:
            bad_covariance = scales is not None or rotations is not None
        if bad_covariance:
            raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')

        # Substitute an empty tensor for every optional input that was omitted.
        shs = torch.Tensor([]) if shs is None else shs
        colors_precomp = torch.Tensor([]) if colors_precomp is None else colors_precomp
        scales = torch.Tensor([]) if scales is None else scales
        rotations = torch.Tensor([]) if rotations is None else rotations
        cov3D_precomp = torch.Tensor([]) if cov3D_precomp is None else cov3D_precomp

        # Invoke C++/CUDA rasterization routine
        return rasterize_gaussians(
            means3D,
            means2D,
            shs,
            colors_precomp,
            opacities,
            scales,
            rotations,
            cov3D_precomp,
            viewmat,
            settings,
        )