# VistaDream/ops/gs/train.py
# Repository page header captured with this file (not part of the program):
#   hpwang's picture | [Init] | fd5e0f7 | raw | history blame | 3.86 kB
import cv2
import tqdm
import torch
# import lpips
import numpy as np
from ops import utils
import torch.nn.functional as F
import torchvision.transforms as tvtf
from ops.gs.basic import Gaussian_Scene,Frame
from torchmetrics.image import StructuralSimilarityIndexMeasure
class RGB_Loss():
    '''
    Photometric loss: smooth-L1 over the (optionally masked) pixels plus a
    weighted (1 - SSIM) term computed on the whole image.

    Args:
        w_lpips: stored on the instance but not used by the loss, because the
            LPIPS term is disabled in this file.
        w_ssim: weight applied to the (1 - SSIM) term.
        device: device hosting the SSIM metric. ``None`` (default) selects
            'cuda' when it is available and falls back to 'cpu' otherwise.
    '''
    def __init__(self,w_lpips=0.2,w_ssim=0.2,device=None):
        if device is None:
            # do not hard-code 'cuda': constructing the metric would fail on CPU-only hosts
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.device = device
        self.rgb_loss = F.smooth_l1_loss
        # LPIPS is disabled; it needs `import lpips` to be restored first:
        # self.lpips_alex = lpips.LPIPS(net='alex').to(self.device)
        # data_range=1.0 -- presumably images are expected in [0, 1]; confirm with callers
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(device)
        self.w_ssim = w_ssim
        self.w_lpips = w_lpips

    def __call__(self,pr,gt,valid_mask=None):
        '''
        Compute the loss between a prediction and its target.

        Args:
            pr: predicted image, channel-last 3-D tensor or a 2-D tensor.
            gt: target image, same layout as ``pr``.
            valid_mask: optional boolean index over the two leading dimensions;
                only the selected pixels feed the smooth-L1 term. The SSIM term
                always uses the full image.

        Returns:
            Scalar tensor ``l_rgb + w_ssim * (1 - ssim)``.
        '''
        # NaN/inf in either image would otherwise propagate into the loss
        pr = torch.nan_to_num(pr)
        gt = torch.nan_to_num(gt)
        # 2-D inputs are expanded to three identical channels
        if pr.dim() < 3: pr = pr[:,:,None].repeat(1,1,3)
        if gt.dim() < 3: gt = gt[:,:,None].repeat(1,1,3)
        if valid_mask is not None:
            pr_valid = pr[valid_mask]
            gt_valid = gt[valid_mask]
        else:
            pr_valid = pr.reshape(-1,pr.shape[-1])
            gt_valid = gt.reshape(-1,gt.shape[-1])
        if pr_valid.numel() == 0:
            # Averaging over zero selected pixels yields NaN, which would corrupt
            # the optimizer step; use a zero that stays attached to the graph.
            l_rgb = pr_valid.sum() * 0.0
        else:
            l_rgb = self.rgb_loss(pr_valid,gt_valid)
        # the metric is fed a batch of one, channel-first image
        pr_nchw = pr[None].permute(0, 3, 1, 2)
        gt_nchw = gt[None].permute(0, 3, 1, 2)
        l_ssim = 1.0 - self.ssim(pr_nchw, gt_nchw)
        # l_lpips = self.lpips_alex(pr_nchw, gt_nchw)
        return l_rgb + self.w_ssim * l_ssim
class GS_Train_Tool():
    '''
    Frames and well-trained gaussians are kept, refine the trainable gaussians
    The supervision comes from the Frames of GS_Scene
    '''
    def __init__(self,
                 GS:"Gaussian_Scene",
                 iters = 100) -> None:
        '''
        Args:
            GS: scene whose trainable gaussian frames are refined in place.
            iters: number of optimization steps performed by ``__call__``.
        '''
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # hyperparameters for prune, densify, and update
        # (set here but not read anywhere else in this class)
        self.lr_factor = 1.00
        self.lr_update = 0.99
        # learning rate
        self.rgb_lr = 0.0005
        self.xyz_lr = 0.0001
        self.scale_lr = 0.005
        self.opacity_lr = 0.05
        self.rotation_lr = 0.001
        # GSs for training
        self.GS = GS
        # hyperparameters for training
        self.iters = iters
        self._init_optimizer()
        self.rgb_lossfunc = RGB_Loss(w_lpips=0)

    def _init_optimizer(self):
        '''Build one Adam optimizer with a parameter group (and lr) per gaussian attribute.'''
        # only gaussian frames whose rgb tensor requires grad are optimized
        self.optimize_frames = [gf for gf in self.GS.gaussian_frames if gf.rgb.requires_grad]
        # following https://github.com/pointrix-project/msplat
        self.optimizer = torch.optim.Adam([
            {'params': [gf.xyz for gf in self.optimize_frames], 'lr': self.xyz_lr},
            {'params': [gf.rgb for gf in self.optimize_frames], 'lr': self.rgb_lr},
            {'params': [gf.scale for gf in self.optimize_frames], 'lr': self.scale_lr},
            {'params': [gf.opacity for gf in self.optimize_frames], 'lr': self.opacity_lr},
            {'params': [gf.rotation for gf in self.optimize_frames], 'lr': self.rotation_lr}
        ])

    def _render(self,frame):
        '''Render ``frame`` with the scene and return ``(rgb, dpt, alpha)``.'''
        rgb,dpt,alpha = self.GS._render_RGBD(frame)
        return rgb,dpt,alpha

    def _to_cuda(self,tensor):
        '''
        Convert a numpy array to a float32 torch tensor on ``self.device``.

        The name is kept for compatibility; the target is 'cuda' whenever CUDA
        is available and 'cpu' otherwise, instead of a hard-coded 'cuda'.
        '''
        return torch.from_numpy(tensor.astype(np.float32)).to(self.device)

    def __call__(self,target_frames=None):
        '''
        Run ``self.iters`` optimization steps, each supervised by one randomly
        chosen frame, then call ``_require_grad(False)`` on every gaussian frame.

        Args:
            target_frames: frames used for supervision; defaults to ``self.GS.frames``.

        Returns:
            The scene held in ``self.GS`` (refined in place).

        Raises:
            ValueError: if training is requested but there is no frame to sample.
        '''
        target_frames = self.GS.frames if target_frames is None else target_frames
        n_frames = len(target_frames)
        if self.iters > 0 and n_frames == 0:
            # np.random.randint(0, 0) would raise an opaque "low >= high" ValueError
            raise ValueError('GS_Train_Tool requires at least one target frame for training')
        for _step in tqdm.tqdm(range(self.iters)):
            frame_idx = np.random.randint(0,n_frames)
            frame = target_frames[frame_idx]
            # depth and alpha are rendered as well but only rgb is supervised
            render_rgb,_render_dpt,_render_alpha = self._render(frame)
            # frame.inpaint selects the pixels that enter the smooth-L1 term
            loss = self.rgb_lossfunc(render_rgb,self._to_cuda(frame.rgb),valid_mask=frame.inpaint)
            # optimization
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()
        refined_scene = self.GS
        # presumably freezes the gaussians after refinement -- see ops.gs.basic
        for gf in refined_scene.gaussian_frames:
            gf._require_grad(False)
        return refined_scene