VistaDream / ops /gs /basic.py
hpwang's picture
[Init]
fd5e0f7
import PIL
import torch
import numpy as np
import gsplat as gs
import torch.nn as nn
from copy import deepcopy
import torch.nn.functional as F
from dataclasses import dataclass
from ops.utils import (
dpt2xyz,
alpha_inpaint_mask,
transform_points,
numpy_normalize,
numpy_quaternion_from_matrix
)
@dataclass
class Frame():
'''
rgb: in shape of H*W*3, in range of 0-1
dpt: in shape of H*W, real depth
inpaint: bool mask in shape of H*W for inpainting
intrinsic: 3*3
extrinsic: array in shape of 4*4
As a class for:
initialize camera
accept rendering result
accept inpainting result
All at 2D-domain
'''
def __init__(self,
H: int = None,
W: int = None,
rgb: np.array = None,
dpt: np.array = None,
sky: np.array = None,
inpaint: np.array = None,
intrinsic: np.array = None,
extrinsic: np.array = None,
# detailed target
ideal_dpt: np.array = None,
ideal_nml: np.array = None,
prompt: str = None) -> None:
self.H = H
self.W = W
self.rgb = rgb
self.dpt = dpt
self.sky = sky
self.prompt = prompt
self.intrinsic = intrinsic
self.extrinsic = extrinsic
self._rgb_rect()
self._extr_rect()
# for inpainting
self.inpaint = inpaint
self.inpaint_wo_edge = inpaint
# for supervision
self.ideal_dpt = ideal_dpt
self.ideal_nml = ideal_nml
def _rgb_rect(self):
if self.rgb is not None:
if isinstance(self.rgb, PIL.PngImagePlugin.PngImageFile):
self.rgb = np.array(self.rgb)
if isinstance(self.rgb, PIL.JpegImagePlugin.JpegImageFile):
self.rgb = np.array(self.rgb)
if np.amax(self.rgb) > 1.1:
self.rgb = self.rgb / 255
def _extr_rect(self):
if self.extrinsic is None: self.extrinsic = np.eye(4)
self.inv_extrinsic = np.linalg.inv(self.extrinsic)
@dataclass
class Gaussian_Frame():
'''
In-frame-frustrum
Gaussians from a single RGBD frame
As a class for:
accept information from initialized/inpainting+geo-estimated frame
saving pixelsplat properties including rgb, xyz, scale, rotation, opacity; note here, we made a modification to xyz;
we first project depth to xyz
then we tune a scale map(initialized to ones) and a shift map(initialized to zeros), they are optimized and add to the original xyz when rendering
'''
# as pixelsplat guassian
rgb: torch.Tensor = None,
scale: torch.Tensor = None,
opacity: torch.Tensor = None,
rotation: torch.Tensor = None,
# gaussian center
dpt: torch.Tensor = None,
xyz: torch.Tensor = None,
# as a frame
H: int = 480,
W: int = 640,
def __init__(self, frame: Frame, device = 'cuda'):
'''after inpainting'''
# de-active functions
self.rgbs_deact = torch.logit
self.scales_deact = torch.log
self.opacity_deact = torch.logit
self.device = device
# for gaussian initialization
self._set_property_from_frame(frame)
def _to_3d(self):
# inv intrinsic
xyz = dpt2xyz(self.dpt,self.intrinsic)
inv_extrinsic = np.linalg.inv(self.extrinsic)
xyz = transform_points(xyz,inv_extrinsic)
return xyz
def _paint_filter(self,paint_mask):
if np.sum(paint_mask)<3:
paint_mask = np.zeros((self.H,self.W))
paint_mask[0:1] = 1
paint_mask = paint_mask>.5
self.rgb = self.rgb[paint_mask]
self.xyz = self.xyz[paint_mask]
self.scale = self.scale[paint_mask]
self.opacity = self.opacity[paint_mask]
self.rotation = self.rotation[paint_mask]
def _to_cuda(self):
self.rgb = torch.from_numpy(self.rgb.astype(np.float32)).to(self.device)
self.xyz = torch.from_numpy(self.xyz.astype(np.float32)).to(self.device)
self.scale = torch.from_numpy(self.scale.astype(np.float32)).to(self.device)
self.opacity = torch.from_numpy(self.opacity.astype(np.float32)).to(self.device)
self.rotation = torch.from_numpy(self.rotation.astype(np.float32)).to(self.device)
def _fine_init_scale_rotations(self):
# from https://arxiv.org/pdf/2406.09394
""" Compute rotation matrices that align z-axis with given normal vectors using matrix operations. """
up_axis = np.array([0,1,0])
nml = self.nml @ self.extrinsic[0:3,0:3]
qz = numpy_normalize(nml)
qx = np.cross(up_axis,qz)
qx = numpy_normalize(qx)
qy = np.cross(qz,qx)
qy = numpy_normalize(qy)
rot = np.concatenate([qx[...,None],qy[...,None],qz[...,None]],axis=-1)
self.rotation = numpy_quaternion_from_matrix(rot)
# scale
safe_nml = deepcopy(self.nml)
safe_nml[safe_nml[:,:,-1]<0.2,-1] = .2
normal_xoz = deepcopy(safe_nml)
normal_yoz = deepcopy(safe_nml)
normal_xoz[...,1] = 0.
normal_yoz[...,0] = 0.
normal_xoz = numpy_normalize(normal_xoz)
normal_yoz = numpy_normalize(normal_yoz)
cos_theta_x = np.abs(normal_xoz[...,2])
cos_theta_y = np.abs(normal_yoz[...,2])
scale_basic = self.dpt / self.intrinsic[0,0] / np.sqrt(2)
scale_x = scale_basic / cos_theta_x
scale_y = scale_basic / cos_theta_y
scale_z = (scale_x + scale_y) / 10.
self.scale = np.concatenate([scale_x[...,None],
scale_y[...,None],
scale_z[...,None]],axis=-1)
def _coarse_init_scale_rotations(self):
# gaussian property -- HW3 scale
self.scale = self.dpt / self.intrinsic[0,0] / np.sqrt(2)
self.scale = self.scale[:,:,None].repeat(3,-1)
# gaussian property -- HW4 rotation
self.rotation = np.zeros((self.H,self.W,4))
self.rotation[:,:,0] = 1.
def _set_property_from_frame(self,frame: Frame):
'''frame here is a complete init/inpainted frame'''
# basic frame-level property
self.H = frame.H
self.W = frame.W
self.dpt = frame.dpt
self.intrinsic = frame.intrinsic
self.extrinsic = frame.extrinsic
# gaussian property -- xyz with train-able pixel-aligned scale and shift
self.xyz = self._to_3d()
# gaussian property -- HW3 rgb
self.rgb = frame.rgb
# gaussian property -- HW4 rotation HW3 scale
self._coarse_init_scale_rotations()
# gaussian property -- HW opacity
self.opacity = np.ones((self.H,self.W,1)) * 0.8
# to cuda
self._paint_filter(frame.inpaint_wo_edge)
self._to_cuda()
# de-activate
self.rgb = self.rgbs_deact(self.rgb)
self.scale = self.scales_deact(self.scale)
self.opacity = self.opacity_deact(self.opacity)
# to torch parameters
self.rgb = nn.Parameter(self.rgb,requires_grad=False)
self.xyz = nn.Parameter(self.xyz,requires_grad=False)
self.scale = nn.Parameter(self.scale,requires_grad=False)
self.opacity = nn.Parameter(self.opacity,requires_grad=False)
self.rotation = nn.Parameter(self.rotation,requires_grad=False)
def _require_grad(self,sign=True):
self.rgb = self.rgb.requires_grad_(sign)
self.xyz = self.xyz.requires_grad_(sign)
self.scale = self.scale.requires_grad_(sign)
self.opacity = self.opacity.requires_grad_(sign)
self.rotation = self.rotation.requires_grad_(sign)
class Gaussian_Scene():
def __init__(self,cfg=None):
# frames initialing the frame
self.frames = []
self.gaussian_frames: list[Gaussian_Frame] = [] # gaussian frame require training at this optimization
# activate fuctions
self.rgbs_act = torch.sigmoid
self.scales_act = torch.exp
self.opacity_act = torch.sigmoid
self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
# for traj generation
self.traj_type = 'spiral'
if cfg is not None:
self.traj_min_percentage = cfg.scene.traj.near_percentage
self.traj_max_percentage = cfg.scene.traj.far_percentage
self.traj_forward_ratio = cfg.scene.traj.traj_forward_ratio
self.traj_backward_ratio = cfg.scene.traj.traj_backward_ratio
else:
self.traj_min_percentage,self.traj_max_percentage,self.traj_forward_ratio,self.traj_backward_ratio = 5, 50, 0.3, 0.4
# basic operations
def _render_RGBD(self,frame,background_color='black'):
'''
:intinsic: tensor of [fu,fv,cu,cv] 4-dimension
:extinsic: tensor 4*4-dimension
:out: tensor H*W*3-dimension
'''
background = None
if background_color =='white':
background = torch.ones(1,4,device=self.device)*0.1
background[:,-1] = 0. # for depth
# aligned untrainable xyz and unaligned trainable xyz
# others
xyz = torch.cat([gf.xyz.reshape(-1,3) for gf in self.gaussian_frames],dim=0)
rgb = torch.cat([gf.rgb.reshape(-1,3) for gf in self.gaussian_frames],dim=0)
scale = torch.cat([gf.scale.reshape(-1,3) for gf in self.gaussian_frames],dim=0)
opacity = torch.cat([gf.opacity.reshape(-1) for gf in self.gaussian_frames],dim=0)
rotation = torch.cat([gf.rotation.reshape(-1,4) for gf in self.gaussian_frames],dim=0)
# activate
rgb = self.rgbs_act(rgb)
scale = self.scales_act(scale)
rotation = F.normalize(rotation,dim=1)
opacity = self.opacity_act(opacity)
# property
H,W = frame.H, frame.W
intrinsic = torch.from_numpy(frame.intrinsic.astype(np.float32)).to(self.device)
extrinsic = torch.from_numpy(frame.extrinsic.astype(np.float32)).to(self.device)
# render
render_out,render_alpha,_ = gs.rendering.rasterization(means = xyz,
scales = scale,
quats = rotation,
opacities = opacity,
colors = rgb,
Ks = intrinsic[None],
viewmats = extrinsic[None],
width = W,
height = H,
packed = False,
near_plane= 0.01,
render_mode="RGB+ED",
backgrounds=background) # render: 1*H*W*(3+1)
render_out = render_out.squeeze() # result: H*W*(3+1)
render_rgb = render_out[:,:,0:3]
render_dpt = render_out[:,:,-1]
return render_rgb, render_dpt, render_alpha
@torch.no_grad()
def _render_for_inpaint(self,frame):
# first render
render_rgb, render_dpt, render_alpha = self._render_RGBD(frame)
render_msk = alpha_inpaint_mask(render_alpha)
# to numpy
render_rgb = render_rgb.detach().cpu().numpy()
render_dpt = render_dpt.detach().cpu().numpy()
render_alpha = render_alpha.detach().cpu().numpy()
# assign back
frame.rgb = render_rgb
frame.dpt = render_dpt
frame.inpaint = render_msk
return frame
def _add_trainable_frame(self,frame:Frame,require_grad=True):
# for the init frame, we keep all pixels for finetuning
self.frames.append(frame)
gf = Gaussian_Frame(frame, self.device)
gf._require_grad(require_grad)
self.gaussian_frames.append(gf)