File size: 6,491 Bytes
fd5e0f7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
'''
Coarse Gaussian rendering -- RGB-D renders serve as the initialization.
Noise is added to the RGB-D renders (multi-view initialization), then we cycle:
    denoise to x0 and d0 -- optimize the Gaussians against them
    re-render RGB-D from the optimized Gaussians
    convert the re-rendered RGB-D into rectified noise
    rectify the diffusion noise
    take one denoising step with the rectified noise
Finally, return the refined Gaussian scene.
'''
import torch
import numpy as np
from copy import deepcopy
from ops.utils import *
from ops.gs.train import *
from ops.trajs import _generate_trajectory
from ops.gs.basic import Frame,Gaussian_Scene

class Refinement_Tool_MCS():
    """Multi-view coherent sampling (MCS) refinement of a coarse Gaussian scene.

    Alternates between (a) denoising multi-view RGB renders toward x0 with a
    latent-consistency refiner, (b) optimizing the Gaussians against those
    denoised targets, and (c) rectifying the diffusion noise using re-rendered
    images, as outlined in the module docstring.
    """

    def __init__(self,
                 coarse_GS:Gaussian_Scene,
                 device = 'cuda',
                 refiner = None,
                 traj_type = 'spiral',
                 n_view = 8,
                 rect_w = 0.7,
                 n_gsopt_iters = 256) -> None:
        """
        Args:
            coarse_GS: the coarse Gaussian scene to refine.
            device: torch device string for the refiner and tensors.
            refiner: latent-consistency RGB diffusion model; must expose
                `denoise_steps`, `timesteps`, `_encode_mv_init_images`,
                `_denoise_to_x0`, `_step_denoise`, and `model._encode_text_prompt`.
            traj_type: camera trajectory family (stored; trajectory generation
                is delegated to `_generate_trajectory`).
            n_view: number of novel views used for multi-view supervision.
            rect_w: weight of the noise rectification term.
            n_gsopt_iters: Gaussian-optimization iterations per denoise step.
        """
        # We refine rendered frames rather than the Gaussian parameters directly.
        self.n_view = n_view
        self.rect_w = rect_w
        self.n_gsopt_iters = n_gsopt_iters
        self.coarse_GS = coarse_GS
        self.refine_frames: list[Frame] = []
        # Working resolution for rendering/denoising.
        self.process_res = 512
        self.device = device
        self.traj_type = traj_type
        # Diffusion refiner model.
        self.RGB_LCM = refiner
        # BUGFIX: honor the `device` argument instead of a hard-coded 'cuda'
        # (the rest of the class already uses self.device).
        self.RGB_LCM.to(self.device)
        self.steps = self.RGB_LCM.denoise_steps
        # Text prompt of the last coarse frame conditions the diffusion.
        prompt = self.coarse_GS.frames[-1].prompt
        self.rgb_prompt_latent = self.RGB_LCM.model._encode_text_prompt(prompt)
        # Photometric loss (L1/L2 + SSIM mix; exact form defined in RGB_Loss).
        self.rgb_lossfunc = RGB_Loss(w_ssim=0.2)

    def _pre_process(self):
        """Build the novel-view frames (shared intrinsics, trajectory poses)
        and render their inpaint masks from the two input views."""
        # Padding so the diffusion target stays a multiple-friendly shape;
        # the border is added symmetrically on each side.
        strict_times = 32
        origin_H = self.coarse_GS.frames[0].H
        origin_W = self.coarse_GS.frames[0].W
        self.target_H, self.target_W = self.process_res, self.process_res
        # Rescale the intrinsics of frame 0 to the target resolution.
        intrinsic = deepcopy(self.coarse_GS.frames[0].intrinsic)
        H_ratio, W_ratio = self.target_H / origin_H, self.target_W / origin_W
        intrinsic[0] *= W_ratio
        intrinsic[1] *= H_ratio
        target_H, target_W = self.target_H + 2 * strict_times, self.target_W + 2 * strict_times
        # Re-center the principal point for the padded shape.
        intrinsic[0, -1] = target_W / 2
        intrinsic[1, -1] = target_H / 2
        # Generate camera poses; drop the first/last to avoid degenerate
        # endpoint views (hence nframes = n_view + 2).
        trajs = _generate_trajectory(None, self.coarse_GS, nframes=self.n_view + 2)[1:-1]
        for pose in trajs:
            fine_frame = Frame()
            fine_frame.H = target_H
            fine_frame.W = target_W
            fine_frame.extrinsic = pose
            fine_frame.intrinsic = deepcopy(intrinsic)
            fine_frame.prompt = self.coarse_GS.frames[-1].prompt
            self.refine_frames.append(fine_frame)
        # Determine the inpaint mask from a scene holding only the two input
        # frames; _render_for_inpaint is assumed to annotate the frame in
        # place (NOTE(review): confirm against its implementation).
        temp_scene = Gaussian_Scene()
        temp_scene._add_trainable_frame(self.coarse_GS.frames[0], require_grad=False)
        temp_scene._add_trainable_frame(self.coarse_GS.frames[1], require_grad=False)
        for frame in self.refine_frames:
            frame = temp_scene._render_for_inpaint(frame)

    def _mv_init(self):
        """Render all novel views from the coarse scene and encode them as
        the multi-view initialization of the diffusion refiner."""
        rgbs = []
        for frame in self.refine_frames:
            # All frames share the same target shape, so renders stack cleanly.
            render_rgb, _, _ = self.coarse_GS._render_RGBD(frame)
            # HWC -> 1CHW for batching.
            rgbs.append(render_rgb.permute(2, 0, 1)[None])
        self.rgbs = torch.cat(rgbs, dim=0)
        self.RGB_LCM._encode_mv_init_images(self.rgbs)

    def _to_cuda(self, tensor):
        """Convert a numpy array to a float32 torch tensor on self.device."""
        # BUGFIX: use the configured device instead of hard-coded 'cuda'.
        tensor = torch.from_numpy(tensor.astype(np.float32)).to(self.device)
        return tensor

    def _x0_rectification(self, denoise_rgb, iters):
        """Optimize a copy of the coarse Gaussians toward the denoised x0
        images, while keeping the two input views as anchors.

        Args:
            denoise_rgb: denoised multi-view images, shape (n_view, H, W, C).
            iters: number of optimization iterations.
        """
        # Start each rectification round from the coarse scene, with
        # gradients enabled on every gaussian frame.
        CGS = deepcopy(self.coarse_GS)
        for gf in CGS.gaussian_frames:
            gf._require_grad(True)
        self.refine_GS = GS_Train_Tool(CGS)
        for _it in range(iters):  # avoid shadowing builtin `iter`
            loss = 0.
            # Anchor supervision on the two input views, weighted up by the
            # number of novel views so both terms stay balanced.
            for i in range(2):
                keep_frame: Frame = self.coarse_GS.frames[i]
                render_rgb, _, _ = self.refine_GS._render(keep_frame)
                loss_rgb = self.rgb_lossfunc(render_rgb,
                                             self._to_cuda(keep_frame.rgb),
                                             valid_mask=keep_frame.inpaint)
                loss += loss_rgb * len(self.refine_frames)
            # Multi-view supervision against the denoised images.
            for i, frame in enumerate(self.refine_frames):
                render_rgb, _, _ = self.refine_GS._render(frame)
                loss += self.rgb_lossfunc(denoise_rgb[i], render_rgb)
            loss.backward()
            self.refine_GS.optimizer.step()
            self.refine_GS.optimizer.zero_grad()

    def _step_gaussian_optimization(self, step):
        """Denoise the current latents to x0 and fit the Gaussians to them.

        Returns:
            (rgb_t, rgb_noise_pr): the timestep tensor and predicted noise,
            needed later by the rectification step.
        """
        with torch.no_grad():
            # Walk the tail of the schedule: step 0 maps to the
            # (steps)-th-from-last timestep.
            rgb_t = self.RGB_LCM.timesteps[-self.steps + step]
            rgb_t = torch.tensor([rgb_t]).to(self.device)
            rgb_noise_pr, rgb_denoise = self.RGB_LCM._denoise_to_x0(rgb_t, self.rgb_prompt_latent)
            # NCHW -> NHWC to match the renderer's image layout.
            rgb_denoise = rgb_denoise.permute(0, 2, 3, 1)
        self._x0_rectification(rgb_denoise, self.n_gsopt_iters)
        return rgb_t, rgb_noise_pr

    def _step_diffusion_rectification(self, rgb_t, rgb_noise_pr):
        """Re-render the refined Gaussians and use the renders to rectify the
        diffusion noise before the next denoising step."""
        with torch.no_grad():
            x0_rect = []
            for frame in self.refine_frames:
                re_render_rgb, _, _ = self.refine_GS._render(frame)
                # Rasterization holes would otherwise compound into larger
                # block artifacts over the cycle.
                x0_rect.append(re_render_rgb.permute(2, 0, 1)[None])
            x0_rect = torch.cat(x0_rect, dim=0)
        self.RGB_LCM._step_denoise(rgb_t, rgb_noise_pr, x0_rect, rect_w=self.rect_w)

    def __call__(self):
        """Run the full refinement cycle and return the refined scene with
        gradients disabled."""
        self._pre_process()
        self._mv_init()
        for step in tqdm.tqdm(range(self.steps)):
            rgb_t, rgb_noise_pr = self._step_gaussian_optimization(step)
            self._step_diffusion_rectification(rgb_t, rgb_noise_pr)
        scene = self.refine_GS.GS
        for gf in scene.gaussian_frames:
            gf._require_grad(False)
        return scene