'''
Multi-view Consistency Sampling (MCS) refinement of a coarse Gaussian scene.

Pipeline:
    1. Render the coarse Gaussians to RGB-D and add noise (multi-view init).
    2. Cycle over the remaining denoising steps:
        a. denoise to x0 (and d0) and optimize the Gaussians against it;
        b. re-render RGB-D from the optimized Gaussians;
        c. rectify the predicted noise using the re-rendered images;
        d. take one denoising step with the rectified noise.
    3. Return the refined Gaussian scene.
'''
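
# A sketch of the noise rectification this pipeline relies on. The actual update
# lives in the refiner's _step_denoise; the form below is an assumption about it,
# not verified internals. Given the noise implied by the re-rendered views,
#     eps_render = (x_t - sqrt(alpha_bar_t) * x0_rendered) / sqrt(1 - alpha_bar_t),
# the step is taken with the blended noise
#     eps_rect = rect_w * eps_render + (1 - rect_w) * eps_pred,
# so rect_w in [0,1] controls how strongly the Gaussian renders steer the sampling.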
import tqdm
import torch
import numpy as np
from copy import deepcopy
from ops.utils import *
from ops.gs.train import *  # expected to provide GS_Train_Tool and RGB_Loss
from ops.trajs import _generate_trajectory
from ops.gs.basic import Frame,Gaussian_Scene

class Refinement_Tool_MCS():
def __init__(self,
coarse_GS:Gaussian_Scene,
device = 'cuda',
refiner = None,
traj_type = 'spiral',
n_view = 8,
rect_w = 0.7,
n_gsopt_iters = 256) -> None:
        # input coarse Gaussian scene;
        # refine_frames holds the novel views to refine: the diffusion operates on rendered frames rather than on Gaussian parameters directly
self.n_view = n_view
self.rect_w = rect_w
self.n_gsopt_iters = n_gsopt_iters
self.coarse_GS = coarse_GS
self.refine_frames: list[Frame] = []
        # hyperparameters; the full diffusion schedule has 50 steps, of which only the last N (refiner.denoise_steps) are run here
self.process_res = 512
self.device = device
self.traj_type = traj_type
# models
self.RGB_LCM = refiner
        self.RGB_LCM.to(self.device)
self.steps = self.RGB_LCM.denoise_steps
# prompt for diffusion
prompt = self.coarse_GS.frames[-1].prompt
self.rgb_prompt_latent = self.RGB_LCM.model._encode_text_prompt(prompt)
# loss function
self.rgb_lossfunc = RGB_Loss(w_ssim=0.2)
def _pre_process(self):
# determine the diffusion target shape
        strict_times = 32  # extra border in pixels padded on each side of the target resolution
origin_H = self.coarse_GS.frames[0].H
origin_W = self.coarse_GS.frames[0].W
self.target_H,self.target_W = self.process_res,self.process_res
# reshape to the same (target) shape for rendering and denoising
intrinsic = deepcopy(self.coarse_GS.frames[0].intrinsic)
H_ratio, W_ratio = self.target_H/origin_H, self.target_W/origin_W
intrinsic[0] *= W_ratio
intrinsic[1] *= H_ratio
target_H, target_W = self.target_H+2*strict_times, self.target_W+2*strict_times
intrinsic[0,-1] = target_W/2
intrinsic[1,-1] = target_H/2
        # generate novel camera poses along the trajectory, dropping the two endpoint views
        trajs = _generate_trajectory(None,self.coarse_GS,nframes=self.n_view+2)[1:-1]
for i, pose in enumerate(trajs):
fine_frame = Frame()
fine_frame.H = target_H
fine_frame.W = target_W
fine_frame.extrinsic = pose
fine_frame.intrinsic = deepcopy(intrinsic)
fine_frame.prompt = self.coarse_GS.frames[-1].prompt
self.refine_frames.append(fine_frame)
        # determine the inpaint mask of each view by rendering a scene built from the two input frames only
temp_scene = Gaussian_Scene()
temp_scene._add_trainable_frame(self.coarse_GS.frames[0],require_grad=False)
temp_scene._add_trainable_frame(self.coarse_GS.frames[1],require_grad=False)
for frame in self.refine_frames:
frame = temp_scene._render_for_inpaint(frame)
def _mv_init(self):
rgbs = []
        # collect renders of the coarse scene at every refinement view
        for frame in self.refine_frames:
            # all renders share the same padded target shape
            render_rgb,_,_ = self.coarse_GS._render_RGBD(frame)
            # HWC -> NCHW images for the diffusion model
            rgbs.append(render_rgb.permute(2,0,1)[None])
self.rgbs = torch.cat(rgbs,dim=0)
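        # hand the renders to the refiner: _encode_mv_init_images is expected to
        # encode them into the latent space and add the initial noise (the
        # multi-view initialization described in the module docstring)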
self.RGB_LCM._encode_mv_init_images(self.rgbs)
    def _to_cuda(self,tensor):
        tensor = torch.from_numpy(tensor.astype(np.float32)).to(self.device)
        return tensor
def _x0_rectification(self, denoise_rgb, iters):
# gaussian initialization
CGS = deepcopy(self.coarse_GS)
for gf in CGS.gaussian_frames:
gf._require_grad(True)
self.refine_GS = GS_Train_Tool(CGS)
# rectification
        for _ in range(iters):
loss = 0.
            # supervision on the two kept input views
            for i in range(2):
                keep_frame: Frame = self.coarse_GS.frames[i]
                render_rgb,render_dpt,render_alpha = self.refine_GS._render(keep_frame)
                loss_rgb = self.rgb_lossfunc(render_rgb,self._to_cuda(keep_frame.rgb),valid_mask=keep_frame.inpaint)
                # weight the input views up to balance the multi-view term below
                loss += loss_rgb*len(self.refine_frames)
            # multi-view supervision from the denoised images
for i,frame in enumerate(self.refine_frames):
render_rgb,render_dpt,render_alpha = self.refine_GS._render(frame)
loss_rgb_item = self.rgb_lossfunc(denoise_rgb[i],render_rgb)
loss += loss_rgb_item
# optimization
loss.backward()
self.refine_GS.optimizer.step()
self.refine_GS.optimizer.zero_grad()
def _step_gaussian_optimization(self,step):
# denoise to x0 and d0
with torch.no_grad():
            # index into the tail of the schedule: only the last self.steps timesteps are run, where guidance is strongest
            rgb_t = self.RGB_LCM.timesteps[-self.steps+step]
rgb_t = torch.tensor([rgb_t]).to(self.device)
rgb_noise_pr,rgb_denoise = self.RGB_LCM._denoise_to_x0(rgb_t,self.rgb_prompt_latent)
rgb_denoise = rgb_denoise.permute(0,2,3,1)
        # render each refinement view and optimize the Gaussians toward the denoised x0 predictions
self._x0_rectification(rgb_denoise,self.n_gsopt_iters)
return rgb_t, rgb_noise_pr
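
    # rgb_t and rgb_noise_pr from the step above feed the rectification below,
    # closing one iteration of the MCS cycle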
def _step_diffusion_rectification(self, rgb_t, rgb_noise_pr):
# re-rendering RGB
with torch.no_grad():
x0_rect = []
for i,frame in enumerate(self.refine_frames):
re_render_rgb,_,re_render_alpha= self.refine_GS._render(frame)
                # note: holes from rasterization here would otherwise propagate as block artifacts through the denoising
x0_rect.append(re_render_rgb.permute(2,0,1)[None])
x0_rect = torch.cat(x0_rect,dim=0)
            # take one denoising step with the noise rectified toward the re-rendered images
            self.RGB_LCM._step_denoise(rgb_t,rgb_noise_pr,x0_rect,rect_w=self.rect_w)
def __call__(self):
        # warm-up: build the refinement views and the multi-view initialization
self._pre_process()
self._mv_init()
for step in tqdm.tqdm(range(self.steps)):
rgb_t, rgb_noise_pr = self._step_gaussian_optimization(step)
self._step_diffusion_rectification(rgb_t, rgb_noise_pr)
        # freeze the refined Gaussians before returning
        scene = self.refine_GS.GS
        for gf in scene.gaussian_frames:
            gf._require_grad(False)
        return scene
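
# Minimal usage sketch (commented out so the module stays import-safe).
# `build_coarse_scene` and `load_lcm_refiner` are hypothetical helpers, not
# defined in this file; the constructor arguments mirror the defaults above.
#
#   coarse_scene = build_coarse_scene(...)   # a trained Gaussian_Scene
#   lcm_refiner  = load_lcm_refiner(...)     # exposes denoise_steps, timesteps, ...
#   tool = Refinement_Tool_MCS(coarse_scene,refiner=lcm_refiner,
#                              traj_type='spiral',n_view=8,
#                              rect_w=0.7,n_gsopt_iters=256)
#   refined_scene = tool()                   # returns the refined Gaussian_Scene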