Spaces:
Sleeping
Sleeping
File size: 4,008 Bytes
98a77e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 |
# Copyright (c) 2020-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION, its affiliates and licensors retain all intellectual
# property and proprietary rights in and to this material, related
# documentation and any modifications thereto. Any use, reproduction,
# disclosure or distribution of this material and related documentation
# without an express license agreement from NVIDIA CORPORATION or
# its affiliates is strictly prohibited.
import torch
from ..render import mesh
from ..render import render
from ..render import regularizer
###############################################################################
# Geometry interface
###############################################################################
class DLMesh(torch.nn.Module):
    """Geometry whose vertex positions are optimized on a fixed-topology mesh.

    A clone of ``initial_guess`` is kept in ``self.mesh``. Its vertex positions
    are wrapped in a ``torch.nn.Parameter`` and registered as ``vertex_pos``.
    The triangle indices (``t_pos_idx``) are not made trainable.
    ``initial_guess`` itself is kept as the reference used by the "relative"
    Laplacian regularizer in :meth:`tick`.

    Args:
        initial_guess: Starting mesh; must provide ``clone()``, ``v_pos`` and
            ``t_pos_idx``.
        FLAGS: Configuration namespace. This class reads ``FLAGS.layers``,
            ``FLAGS.iter``, ``FLAGS.laplace`` and ``FLAGS.laplace_scale``.
    """

    # Weight of the albedo (k_d) smoothness regularizer.
    KD_SMOOTHNESS_WEIGHT = 0.03
    # Weight of the visibility (occlusion) regularizer.
    OCCLUSION_WEIGHT = 0.001
    # Iterations over which the two weights above ramp linearly from 0 to full strength.
    REG_WARMUP_ITERS = 500
    # Weight applied to the light's own regularizer (``lgt.regularizer()``).
    LIGHT_REG_WEIGHT = 0.005

    def __init__(self, initial_guess, FLAGS):
        super(DLMesh, self).__init__()
        self.FLAGS = FLAGS
        # Keep the original around: the "relative" Laplacian term measures
        # the displacement of the optimized vertices from it.
        self.initial_guess = initial_guess
        self.mesh = initial_guess.clone()
        print("Base mesh has %d triangles and %d vertices." % (self.mesh.t_pos_idx.shape[0], self.mesh.v_pos.shape[0]))

        # Make vertex positions trainable and register them so they are
        # returned by self.parameters() / moved by self.to(...).
        self.mesh.v_pos = torch.nn.Parameter(self.mesh.v_pos, requires_grad=True)
        self.register_parameter('vertex_pos', self.mesh.v_pos)

    @torch.no_grad()
    def getAABB(self):
        """Return ``mesh.aabb`` of the current mesh, computed without autograd tracking."""
        return mesh.aabb(self.mesh)

    def getMesh(self, material):
        """Return a renderable copy of the current mesh using ``material``.

        Side effect: ``material`` is also assigned to ``self.mesh.material``.
        The returned mesh is a new ``mesh.Mesh`` built from ``self.mesh``, with
        normals recomputed by ``mesh.auto_normals`` and tangents by
        ``mesh.compute_tangents``.
        """
        self.mesh.material = material
        imesh = mesh.Mesh(base=self.mesh)

        # Vertex positions change every step, so derived shading frames must
        # be recomputed each time.
        imesh = mesh.auto_normals(imesh)
        imesh = mesh.compute_tangents(imesh)
        return imesh

    def render(self, glctx, target, lgt, opt_material, bsdf=None):
        """Render the current mesh with ``render.render_mesh``.

        Camera and output settings come from ``target`` (keys ``'mvp'``,
        ``'campos'``, ``'resolution'``, ``'spp'``, ``'background'``). The number
        of layers comes from ``FLAGS.layers``; MSAA is always enabled.

        Returns:
            Whatever ``render.render_mesh`` returns. :meth:`tick` reads its
            ``'shaded'``, ``'kd_grad'`` and ``'occlusion'`` entries.
        """
        opt_mesh = self.getMesh(opt_material)
        return render.render_mesh(glctx, opt_mesh, target['mvp'], target['campos'], lgt, target['resolution'], spp=target['spp'],
                                  num_layers=self.FLAGS.layers, msaa=True, background=target['background'], bsdf=bsdf)

    def tick(self, glctx, target, lgt, opt_material, loss_fn, iteration):
        """Render the current mesh and compute the losses for one optimization step.

        Args:
            glctx: Rasterization context forwarded to :meth:`render`.
            target: Reference view. ``'img'`` is the reference image; its
                channels from index 3 on are used as coverage, 0:3 as color.
                The other keys are consumed by :meth:`render`.
            lgt: Light object forwarded to :meth:`render`; must also provide
                ``regularizer()``.
            opt_material: Material being optimized, forwarded to :meth:`render`.
            loss_fn: Callable ``loss_fn(prediction, reference)`` for the color term.
            iteration: Current iteration index, used for annealing/warm-up.

        Returns:
            Tuple ``(img_loss, reg_loss)``. ``reg_loss`` is a one-element tensor
            on the same device as the mesh vertices.
        """
        # Render the optimizable object under the same conditions as the reference.
        buffers = self.render(glctx, target, lgt, opt_material)

        # Fraction of training completed; the Laplacian weight decays as (1 - t_iter).
        t_iter = iteration / self.FLAGS.iter
        # Linear ramp 0 -> 1 over the first REG_WARMUP_ITERS iterations.
        warmup = min(1.0, iteration / self.REG_WARMUP_ITERS)

        # Image-space loss, split into a coverage component and a color component.
        # Color is compared only where the reference has coverage.
        color_ref = target['img']
        shaded = buffers['shaded']
        ref_coverage = color_ref[..., 3:]
        img_loss = torch.nn.functional.mse_loss(shaded[..., 3:], ref_coverage)
        img_loss = img_loss + loss_fn(shaded[..., 0:3] * ref_coverage, color_ref[..., 0:3] * ref_coverage)

        # Accumulate on the mesh's device rather than assuming "cuda".
        reg_loss = torch.tensor([0], dtype=torch.float32, device=self.mesh.v_pos.device)

        # Laplacian smoothness, either on absolute positions or on the displacement
        # from the initial guess. Any other FLAGS.laplace value disables it.
        if self.FLAGS.laplace == "absolute":
            reg_loss = reg_loss + regularizer.laplace_regularizer_const(self.mesh.v_pos, self.mesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter)
        elif self.FLAGS.laplace == "relative":
            reg_loss = reg_loss + regularizer.laplace_regularizer_const(self.mesh.v_pos - self.initial_guess.v_pos, self.mesh.t_pos_idx) * self.FLAGS.laplace_scale * (1 - t_iter)

        # Albedo (k_d) smoothness regularizer; the last channel of the buffer
        # weights the remaining channels.
        kd_grad = buffers['kd_grad']
        reg_loss = reg_loss + torch.mean(kd_grad[..., :-1] * kd_grad[..., -1:]) * self.KD_SMOOTHNESS_WEIGHT * warmup

        # Visibility regularizer, same weighting scheme as above.
        occlusion = buffers['occlusion']
        reg_loss = reg_loss + torch.mean(occlusion[..., :-1] * occlusion[..., -1:]) * self.OCCLUSION_WEIGHT * warmup

        # Light white balance regularizer.
        reg_loss = reg_loss + lgt.regularizer() * self.LIGHT_REG_WEIGHT

        return img_loss, reg_loss