#
# Copyright (C) 2023, Inria
# GRAPHDECO research group, https://team.inria.fr/graphdeco
# All rights reserved.
#
# This software is free for non-commercial, research and evaluation use
# under the terms of the LICENSE.md file.
#
# For inquiries contact  george.drettakis@inria.fr
#

from typing import NamedTuple
import torch.nn as nn
import torch
from . import _C

# Position of ``viewmat`` among the inputs of ``_RasterizeGaussians.forward``
# (ctx excluded); used to look it up in ``ctx.needs_input_grad``.
_VIEWMAT_INPUT_INDEX = 8


def cpu_deep_copy_tuple(input_tuple):
    """Return a tuple where every tensor of ``input_tuple`` is cloned onto the CPU.

    Non-tensor items are passed through as-is (they are not copied).
    """
    copied_tensors = [
        item.cpu().clone() if isinstance(item, torch.Tensor) else item
        for item in input_tuple
    ]
    return tuple(copied_tensors)


def _invoke_with_snapshot(fn, args, debug, dump_path, message):
    """Call ``fn(*args)``; in debug mode, dump a CPU copy of ``args`` if it raises.

    When ``debug`` is true, the arguments are copied to the CPU *before* the
    call. If the call then raises, the copy is written to ``dump_path`` with
    ``torch.save``, ``message`` is printed, and the original exception is
    re-raised unchanged.
    """
    if not debug:
        return fn(*args)
    cpu_args = cpu_deep_copy_tuple(args)  # Copy them before they can be corrupted
    try:
        return fn(*args)
    except Exception:
        torch.save(cpu_args, dump_path)
        print(message)
        raise


def _viewmat_gradient(means3D, grad_means2D, grad_ts, raster_settings):
    """Return the gradient of the loss w.r.t. the (transposed) view matrix.

    Math reference: https://arxiv.org/pdf/2312.02121.pdf

    ``v_t`` is the gradient w.r.t. each 3D mean in camera coordinates. It
    has two sources:

    * ``grad_means2D``, the gradient w.r.t. the mean in NDC (t' in the
      paper), chained through the perspective divide;
    * ``grad_ts``, the gradient through the affine approximation J used for
      the 2D covariance (gathered by "cuda_rasterizer/backward.cu").

    NOTE(review): assumes means3D is (N, 3) and the matrices are 4x4 stored
    transposed, as the indexing below implies — confirm against callers.
    """
    with torch.no_grad():
        projmat = raster_settings.projmatrix.T
        means_h = torch.cat([means3D, torch.ones_like(means3D[..., :1])], dim=-1)
        p_hom = torch.einsum("ij,nj->ni", projmat, means_h)
        rw = 1 / (p_hom[..., 3] + 1e-5)
        rw2 = torch.square(rw)
        proj = raster_settings.perspectivematrix.flatten()

        gx = grad_means2D[:, 0]
        gy = grad_means2D[:, 1]
        px = p_hom[:, 0]
        py = p_hom[:, 1]

        # d(ndc)/d(t) through the perspective divide, one camera axis at a time.
        v_tx = gx * (proj[0] * rw - proj[3] * px * rw2)
        v_tx += gy * (proj[1] * rw - proj[3] * py * rw2)
        v_ty = gx * (proj[4] * rw - proj[7] * px * rw2)
        v_ty += gy * (proj[5] * rw - proj[7] * py * rw2)
        v_tz = gx * (proj[8] * rw - proj[11] * px * rw2)
        v_tz += gy * (proj[9] * rw - proj[11] * py * rw2)

        v_t = torch.stack([v_tx, v_ty, v_tz, torch.zeros_like(v_tx)], dim=-1)
        # Second gradient source: t enters the affine transform J of the 2D covariance.
        v_t[:, :3] += grad_ts

        # t = V @ [m, 1]  =>  dL/dV = sum_n v_t[n] (outer) [m_n, 1];
        # transposed because the view matrix is stored transposed.
        return torch.einsum("ni,nj->ij", v_t, means_h).T


def rasterize_gaussians(
    means3D,
    means2D,
    sh,
    colors_precomp,
    opacities,
    scales,
    rotations,
    cov3Ds_precomp,
    viewmat,
    raster_settings,
):
    """Differentiably rasterize Gaussians; returns ``(color, radii, depth, alpha)``."""
    return _RasterizeGaussians.apply(
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        viewmat,
        raster_settings,
    )


class _RasterizeGaussians(torch.autograd.Function):
    """Autograd bridge to the C++/CUDA rasterizer (``_C``).

    ``means2D`` and ``viewmat`` are not read by ``forward``, which uses
    ``raster_settings.viewmatrix`` instead. They exist so autograd can route
    the screen-space mean gradients and the view-matrix gradient back to the
    caller.
    """

    @staticmethod
    def forward(
        ctx,
        means3D,
        means2D,
        sh,
        colors_precomp,
        opacities,
        scales,
        rotations,
        cov3Ds_precomp,
        viewmat,
        raster_settings,
    ):
        # Restructure arguments the way that the C++ lib expects them
        args = (
            raster_settings.bg,
            means3D,
            colors_precomp,
            opacities,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            raster_settings.image_height,
            raster_settings.image_width,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            raster_settings.prefiltered,
            raster_settings.debug,
        )

        # Invoke C++/CUDA rasterizer
        (
            num_rendered, color, depth, alpha, radii,
            geomBuffer, binningBuffer, imgBuffer,
        ) = _invoke_with_snapshot(
            _C.rasterize_gaussians,
            args,
            raster_settings.debug,
            "snapshot_fw.dump",
            "\nAn error occured in forward. Please forward snapshot_fw.dump for debugging.",
        )

        # Keep relevant tensors for backward
        ctx.raster_settings = raster_settings
        ctx.num_rendered = num_rendered
        ctx.save_for_backward(
            colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii,
            sh, geomBuffer, binningBuffer, imgBuffer, alpha,
        )
        return color, radii, depth, alpha

    @staticmethod
    def backward(ctx, grad_color, grad_radii, grad_depth, grad_alpha):
        # Restore necessary values from context
        num_rendered = ctx.num_rendered
        raster_settings = ctx.raster_settings
        (
            colors_precomp, means3D, scales, rotations, cov3Ds_precomp, radii,
            sh, geomBuffer, binningBuffer, imgBuffer, alpha,
        ) = ctx.saved_tensors

        # Restructure args as C++ method expects them
        args = (
            raster_settings.bg,
            means3D,
            radii,
            colors_precomp,
            scales,
            rotations,
            raster_settings.scale_modifier,
            cov3Ds_precomp,
            raster_settings.viewmatrix,
            raster_settings.projmatrix,
            raster_settings.tanfovx,
            raster_settings.tanfovy,
            grad_color,
            grad_depth,
            grad_alpha,
            sh,
            raster_settings.sh_degree,
            raster_settings.campos,
            geomBuffer,
            num_rendered,
            binningBuffer,
            imgBuffer,
            alpha,
            raster_settings.debug,
        )

        # Compute gradients for relevant tensors by invoking backward method
        (
            grad_means2D, grad_ts, grad_colors_precomp, grad_opacities,
            grad_means3D, grad_cov3Ds_precomp, grad_sh, grad_scales,
            grad_rotations,
        ) = _invoke_with_snapshot(
            _C.rasterize_gaussians_backward,
            args,
            raster_settings.debug,
            "snapshot_bw.dump",
            "\nAn error occured in backward. Writing snapshot_bw.dump for debugging.\n",
        )

        # Autograd rejects a non-None gradient for an input that was not a
        # tensor (e.g. the default viewmat=None), so only build it on request.
        if ctx.needs_input_grad[_VIEWMAT_INPUT_INDEX]:
            grad_viewmat = _viewmat_gradient(
                means3D, grad_means2D, grad_ts, raster_settings
            )
        else:
            grad_viewmat = None

        grads = (
            grad_means3D,
            grad_means2D,
            grad_sh,
            grad_colors_precomp,
            grad_opacities,
            grad_scales,
            grad_rotations,
            grad_cov3Ds_precomp,
            grad_viewmat,
            None,  # raster_settings is not differentiable
        )
        return grads


class GaussianRasterizationSettings(NamedTuple):
    """Camera/render configuration forwarded to the C++/CUDA rasterizer.

    Field order is part of the positional constructor; do not reorder.
    """

    image_height: int
    image_width: int
    tanfovx : float
    tanfovy : float
    bg : torch.Tensor
    scale_modifier : float
    viewmatrix : torch.Tensor
    # Read (flattened) by the view-matrix gradient in backward.
    perspectivematrix: torch.Tensor
    projmatrix : torch.Tensor
    sh_degree : int
    campos : torch.Tensor
    prefiltered : bool
    # When true, rasterizer inputs are snapshotted to disk if a C++ call fails.
    debug : bool


class GaussianRasterizer(nn.Module):
    """``nn.Module`` front-end that validates inputs and calls the rasterizer."""

    def __init__(self, raster_settings):
        super().__init__()
        self.raster_settings = raster_settings

    def markVisible(self, positions):
        """Return ``_C.mark_visible`` for ``positions`` (computed without autograd)."""
        # Mark visible points (based on frustum culling for camera) with a boolean
        with torch.no_grad():
            raster_settings = self.raster_settings
            visible = _C.mark_visible(
                positions,
                raster_settings.viewmatrix,
                raster_settings.projmatrix)

        return visible

    def forward(self, means3D, means2D, opacities, shs = None, colors_precomp = None, scales = None, rotations = None, cov3D_precomp = None, viewmat = None):
        """Rasterize Gaussians; returns ``(color, radii, depth, alpha)``.

        Exactly one of ``shs`` / ``colors_precomp`` and exactly one of the
        ``(scales, rotations)`` pair / ``cov3D_precomp`` must be given; missing
        optional inputs are replaced by empty tensors for the C++ side.
        ``viewmat`` may be omitted; a view-matrix gradient is then not produced.
        """
        raster_settings = self.raster_settings

        if (shs is None and colors_precomp is None) or (shs is not None and colors_precomp is not None):
            raise Exception('Please provide excatly one of either SHs or precomputed colors!')

        if ((scales is None or rotations is None) and cov3D_precomp is None) or ((scales is not None or rotations is not None) and cov3D_precomp is not None):
            raise Exception('Please provide exactly one of either scale/rotation pair or precomputed 3D covariance!')

        if shs is None:
            shs = torch.Tensor([])
        if colors_precomp is None:
            colors_precomp = torch.Tensor([])

        if scales is None:
            scales = torch.Tensor([])
        if rotations is None:
            rotations = torch.Tensor([])
        if cov3D_precomp is None:
            cov3D_precomp = torch.Tensor([])

        # Invoke C++/CUDA rasterization routine
        return rasterize_gaussians(
            means3D,
            means2D,
            shs,
            colors_precomp,
            opacities,
            scales,
            rotations,
            cov3D_precomp,
            viewmat,
            raster_settings,
        )