import random

import torch
import torch.nn as nn
import torch.nn.functional as F
import tqdm
from kiui.mesh_utils import clean_mesh, decimate_mesh, laplacian_smooth_loss, normal_consistency
from pytorch_msssim import SSIM, MS_SSIM

import comfy.utils

from .diff_mesh_renderer import DiffRastRenderer
from shared_utils.camera_utils import BaseCameraController
from shared_utils.image_utils import prepare_torch_img


class DiffMeshCameraController(BaseCameraController):
    def get_render_result(self, render_pose, bg_color, **kwargs):
        ref_cam = (render_pose, self.cam.perspective)
        # ssaa is fixed at 1; it could instead be randomized per render,
        # e.g. ssaa = min(2.0, max(0.125, 2 * np.random.random()))
        return self.renderer.render(*ref_cam, self.cam.H, self.cam.W, ssaa=1, bg_color=bg_color, **kwargs)
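
# Note: DiffMesh.training below drives this controller through render_at_pose
# (inherited from BaseCameraController), which presumably builds the full camera
# for a given orbit pose and delegates to get_render_result above.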


class DiffMesh:
    def __init__(
        self,
        mesh,
        training_iterations,
        batch_size,
        texture_learning_rate,
        train_mesh_geometry,
        geometry_learning_rate,
        ms_ssim_loss_weight,
        remesh_after_n_iteration,
        invert_bg_prob,
        force_cuda_rasterize,
    ):
        self.device = torch.device("cuda")
        self.train_mesh_geometry = train_mesh_geometry
        self.remesh_after_n_iteration = remesh_after_n_iteration
        # keep the learning rates so the optimizer can be rebuilt after remeshing
        self.texture_learning_rate = texture_learning_rate
        self.geometry_learning_rate = geometry_learning_rate

        # prepare main components for optimization
        self.renderer = DiffRastRenderer(mesh, force_cuda_rasterize).to(self.device)
        self.optimizer = torch.optim.Adam(self.renderer.get_params(texture_learning_rate, train_mesh_geometry, geometry_learning_rate))

        #self.ssim_loss = SSIM(data_range=1, size_average=True, channel=3)
        self.ms_ssim_loss = MS_SSIM(data_range=1, size_average=True, channel=3)
        self.lambda_ssim = ms_ssim_loss_weight

        self.training_iterations = training_iterations
        self.batch_size = batch_size
        self.invert_bg_prob = invert_bg_prob

    def prepare_training(self, reference_images, reference_masks, reference_orbit_camera_poses, reference_orbit_camera_fovy):
        self.ref_imgs_num = len(reference_images)
        self.ref_size_H = reference_images[0].shape[0]
        self.ref_size_W = reference_images[0].shape[1]

        # default camera settings
        self.cam_controller = DiffMeshCameraController(
            self.renderer, self.ref_size_W, self.ref_size_H, reference_orbit_camera_fovy, self.invert_bg_prob, None, self.device
        )
        self.all_ref_cam_poses = reference_orbit_camera_poses

        # prepare reference images and masks as channel-first tensors on the training device
        ref_imgs_torch_list = []
        ref_masks_torch_list = []
        for i in range(self.ref_imgs_num):
            ref_imgs_torch_list.append(prepare_torch_img(reference_images[i].unsqueeze(0), self.ref_size_H, self.ref_size_W, self.device))
            ref_masks_torch_list.append(prepare_torch_img(reference_masks[i].unsqueeze(2).unsqueeze(0), self.ref_size_H, self.ref_size_W, self.device))

        self.ref_imgs_torch = torch.cat(ref_imgs_torch_list, dim=0)
        self.ref_masks_torch = torch.cat(ref_masks_torch_list, dim=0)

    def training(self, decimate_target=5e4):
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        # pre-compute masked reference images once so they are reused every step
        ref_imgs_masked = []
        for i in range(self.ref_imgs_num):
            ref_imgs_masked.append((self.ref_imgs_torch[i] * self.ref_masks_torch[i]).unsqueeze(0))

        ref_imgs_num_minus_1 = self.ref_imgs_num - 1

        comfy_pbar = comfy.utils.ProgressBar(self.training_iterations)
        for step in tqdm.trange(self.training_iterations):
            ### calculate loss between reference and rendered images from known views
            loss = 0
            masked_rendered_img_batch = []
            masked_ref_img_batch = []
            for _ in range(self.batch_size):
                # pick a random reference view and render from its pose
                i = random.randint(0, ref_imgs_num_minus_1)
                out = self.cam_controller.render_at_pose(self.all_ref_cam_poses[i])

                image = out["image"]  # [H, W, 3] in [0, 1]
                image = image.permute(2, 0, 1).contiguous()  # [3, H, W] in [0, 1]
                image_masked = (image * self.ref_masks_torch[i]).unsqueeze(0)

                masked_rendered_img_batch.append(image_masked)
                masked_ref_img_batch.append(ref_imgs_masked[i])

            masked_rendered_img_batch_torch = torch.cat(masked_rendered_img_batch, dim=0)
            masked_ref_img_batch_torch = torch.cat(masked_ref_img_batch, dim=0)

            # rgb loss
            loss += (1 - self.lambda_ssim) * F.mse_loss(masked_rendered_img_batch_torch, masked_ref_img_batch_torch)

            # D-SSIM loss, computed on [batch_size, 3, H, W] batches in [0, 1]
            #loss += self.lambda_ssim * (1 - self.ssim_loss(X, Y))
            loss += self.lambda_ssim * (1 - self.ms_ssim_loss(masked_ref_img_batch_torch, masked_rendered_img_batch_torch))
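
            # i.e. the total photometric objective computed above is
            #   (1 - lambda_ssim) * MSE + lambda_ssim * (1 - MS-SSIM)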

            # Regularization losses: keep the deformed surface smooth, keep
            # neighboring face normals consistent, and penalize large vertex offsets
            if self.train_mesh_geometry:
                current_v = self.renderer.mesh.v + self.renderer.v_offsets
                loss += 0.01 * laplacian_smooth_loss(current_v, self.renderer.mesh.f)
                loss += 0.001 * normal_consistency(current_v, self.renderer.mesh.f)
                loss += 0.1 * (self.renderer.v_offsets ** 2).sum(-1).mean()

            # remesh periodically: bake the learned offsets into the vertices,
            # clean (and if needed decimate) the mesh, then restart from zero offsets
            if step > 0 and step % self.remesh_after_n_iteration == 0:
                vertices = (self.renderer.mesh.v + self.renderer.v_offsets).detach().cpu().numpy()
                triangles = self.renderer.mesh.f.detach().cpu().numpy()
                vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)
                if triangles.shape[0] > decimate_target:
                    vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
                self.renderer.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
                self.renderer.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
                # zeros_like inherits mesh.v's device, so no extra .to() is needed
                # (calling .to() on a Parameter can return a plain tensor)
                self.renderer.v_offsets = nn.Parameter(torch.zeros_like(self.renderer.mesh.v))
                # the vertex count changed, so rebuild the optimizer to track the
                # new offset parameter instead of the stale, mismatched one
                self.optimizer = torch.optim.Adam(self.renderer.get_params(self.texture_learning_rate, self.train_mesh_geometry, self.geometry_learning_rate))

            # optimize step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            comfy_pbar.update_absolute(step + 1)

        torch.cuda.synchronize()

        self.need_update = True

        print(f"Final step: {step}")

        # bake the optimized attributes back into the renderer's mesh
        self.renderer.update_mesh()

        ender.record()
        #t = starter.elapsed_time(ender)

    def get_mesh_and_texture(self):
        return (self.renderer.mesh, self.renderer.mesh.albedo,)
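

# Minimal usage sketch (assumptions: the inputs come from the surrounding
# ComfyUI 3D nodes, and the hyper-parameter values below are purely
# illustrative, not recommended settings):
#
#   trainer = DiffMesh(
#       mesh,
#       training_iterations=500,
#       batch_size=3,
#       texture_learning_rate=0.1,
#       train_mesh_geometry=True,
#       geometry_learning_rate=0.001,
#       ms_ssim_loss_weight=0.5,
#       remesh_after_n_iteration=100,
#       invert_bg_prob=0.5,
#       force_cuda_rasterize=False,
#   )
#   trainer.prepare_training(reference_images, reference_masks, reference_orbit_camera_poses, reference_orbit_camera_fovy)
#   trainer.training()
#   mesh, albedo = trainer.get_mesh_and_texture()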