Spaces:
Build error
Build error
import cv2 | |
import tqdm | |
import torch | |
# import lpips | |
import numpy as np | |
from ops import utils | |
import torch.nn.functional as F | |
import torchvision.transforms as tvtf | |
from ops.gs.basic import Gaussian_Scene,Frame | |
from torchmetrics.image import StructuralSimilarityIndexMeasure | |
class RGB_Loss():
    '''
    Photometric loss between a rendered image and its target.

    The value returned by ``__call__`` is
    ``smooth_l1(pr, gt) + w_ssim * (1 - SSIM(pr, gt))``.
    The LPIPS term is commented out in this file, so ``w_lpips`` is stored
    but does not influence the result.
    '''
    def __init__(self,w_lpips=0.2,w_ssim=0.2):
        '''
        w_lpips: weight kept for the (currently disabled) LPIPS term.
        w_ssim:  weight of the ``1 - SSIM`` term.
        '''
        # Choose the device at runtime instead of hard-coding 'cuda', so the
        # loss can also be constructed on hosts without a GPU.
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        self.rgb_loss = F.smooth_l1_loss
        # self.lpips_alex = lpips.LPIPS(net='alex').to(self.device)
        self.ssim = StructuralSimilarityIndexMeasure(data_range=1.0).to(self.device)
        self.w_ssim = w_ssim
        self.w_lpips = w_lpips

    def __call__(self,pr,gt,valid_mask=None):
        '''
        pr, gt: predicted / target image tensors. A tensor with fewer than
            three dims is expanded to three identical channels on a new last
            axis. NOTE(review): inputs are treated as channel-last with values
            in [0, 1] (see the permute and data_range=1.0) -- confirm callers.
        valid_mask: optional index (e.g. a boolean mask) selecting the pixels
            that contribute to the smooth-L1 term. The SSIM term always uses
            the full images.
        Returns the scalar loss tensor.
        '''
        # NaN/inf would otherwise propagate straight into the loss
        pr = torch.nan_to_num(pr)
        gt = torch.nan_to_num(gt)
        if len(pr.shape) < 3: pr = pr[:,:,None].repeat(1,1,3)
        if len(gt.shape) < 3: gt = gt[:,:,None].repeat(1,1,3)
        if valid_mask is not None:
            pr_valid = pr[valid_mask]
            gt_valid = gt[valid_mask]
        else:
            pr_valid = pr.reshape(-1,pr.shape[-1])
            gt_valid = gt.reshape(-1,gt.shape[-1])
        if pr_valid.numel() == 0:
            # An empty selection would make the mean-reduced loss NaN;
            # contribute nothing instead and let the SSIM term drive the step.
            l_rgb = pr.new_zeros(())
        else:
            l_rgb = self.rgb_loss(pr_valid,gt_valid)
        # add a batch axis and move channels first for the SSIM metric
        l_ssim = 1.0 - self.ssim(pr[None].permute(0, 3, 1, 2), gt[None].permute(0, 3, 1, 2))
        # l_lpips = self.lpips_alex(pr[None].permute(0, 3, 1, 2), gt[None].permute(0, 3, 1, 2))
        return l_rgb + self.w_ssim * l_ssim
class GS_Train_Tool():
    '''
    Frames and well-trained gaussians are kept, refine the trainable gaussians
    The supervision comes from the Frames of GS_Scene
    '''
    def __init__(self,
                 GS:"Gaussian_Scene",
                 iters = 100) -> None:
        '''
        GS:    scene providing ``gaussian_frames`` (optimized) and ``frames``
               (default supervision targets).
        iters: number of optimization steps performed by ``__call__``.
        '''
        self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        # hyperparameters for prune, densify, and update
        # (stored here; not read by the code visible in this class)
        self.lr_factor = 1.00
        self.lr_update = 0.99
        # learning rate
        self.rgb_lr = 0.0005
        self.xyz_lr = 0.0001
        self.scale_lr = 0.005
        self.opacity_lr = 0.05
        self.rotation_lr = 0.001
        # GSs for training
        self.GS = GS
        # hyperparameters for training
        self.iters = iters
        self._init_optimizer()
        self.rgb_lossfunc = RGB_Loss(w_lpips=0)

    def _init_optimizer(self):
        '''Build one Adam optimizer over the gaussian frames that require grad.'''
        # only frames whose rgb tensor requires grad are treated as trainable
        self.optimize_frames = [gf for gf in self.GS.gaussian_frames if gf.rgb.requires_grad]
        # following https://github.com/pointrix-project/msplat
        # one parameter group per attribute so each gets its own learning rate
        self.optimizer = torch.optim.Adam([
            {'params': [gf.xyz for gf in self.optimize_frames], 'lr': self.xyz_lr},
            {'params': [gf.rgb for gf in self.optimize_frames], 'lr': self.rgb_lr},
            {'params': [gf.scale for gf in self.optimize_frames], 'lr': self.scale_lr},
            {'params': [gf.opacity for gf in self.optimize_frames], 'lr': self.opacity_lr},
            {'params': [gf.rotation for gf in self.optimize_frames], 'lr': self.rotation_lr}
        ])

    def _render(self,frame):
        '''Render ``frame`` with the scene and return (rgb, depth, alpha).'''
        rgb,dpt,alpha = self.GS._render_RGBD(frame)
        return rgb,dpt,alpha

    def _to_cuda(self,tensor):
        '''
        Convert a numpy array to a float32 torch tensor on ``self.device``.

        The name is kept for callers; the target is the device selected in
        ``__init__`` (CUDA when available, otherwise CPU), not a hard-coded
        'cuda'.
        '''
        return torch.from_numpy(tensor.astype(np.float32)).to(self.device)

    def __call__(self,target_frames=None):
        '''
        Run ``self.iters`` optimization steps and return the refined scene.

        target_frames: frames used as supervision; defaults to ``self.GS.frames``.
            Each step draws one frame uniformly at random, renders it and
            minimizes the RGB loss against ``frame.rgb``, with ``frame.inpaint``
            passed as the valid mask of the loss.
        Returns ``self.GS`` (the same object, updated in place) after calling
        ``_require_grad(False)`` on every gaussian frame.
        '''
        target_frames = self.GS.frames if target_frames is None else target_frames
        for _ in tqdm.tqdm(range(self.iters)):
            frame_idx = np.random.randint(0,len(target_frames))
            frame :Frame = target_frames[frame_idx]
            # depth and alpha are rendered as well but only rgb is supervised
            render_rgb,_,_ = self._render(frame)
            loss_rgb = self.rgb_lossfunc(render_rgb,self._to_cuda(frame.rgb),valid_mask=frame.inpaint)
            # optimization
            loss = loss_rgb
            loss.backward()
            self.optimizer.step()
            # clear gradients so they do not accumulate into the next step
            self.optimizer.zero_grad()
        refined_scene = self.GS
        # NOTE(review): presumably freezes the gaussians after refinement --
        # confirm against Gaussian frame's _require_grad implementation
        for gf in refined_scene.gaussian_frames:
            gf._require_grad(False)
        return refined_scene