File size: 6,999 Bytes
82ea528
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
import random
import tqdm

import torch
import torch.nn as nn
import torch.nn.functional as F
from kiui.mesh_utils import clean_mesh, decimate_mesh
from kiui.mesh_utils import laplacian_smooth_loss, normal_consistency
from pytorch_msssim import SSIM, MS_SSIM

import comfy.utils

from .diff_mesh_renderer import DiffRastRenderer

from shared_utils.camera_utils import BaseCameraController
from shared_utils.image_utils import prepare_torch_img

class DiffMeshCameraController(BaseCameraController):
    """Camera controller whose render backend is the differentiable mesh renderer.

    NOTE(review): ``self.cam`` and ``self.renderer`` are not set in this class;
    presumably ``BaseCameraController`` provides them -- confirm against the base.
    """

    def get_render_result(self, render_pose, bg_color, **kwargs):
        """Render the mesh as seen from ``render_pose``.

        Args:
            render_pose: Camera pose forwarded as the first positional argument
                of ``self.renderer.render``.
            bg_color: Background colour forwarded to the renderer.
            **kwargs: Extra keyword arguments passed through to the renderer
                unchanged.

        Returns:
            Whatever ``self.renderer.render`` returns.
        """
        camera = self.cam
        projection = camera.perspective
        # Supersampling stays fixed at 1; an earlier experiment randomised it with
        # min(2.0, max(0.125, 2 * np.random.random())).
        return self.renderer.render(
            render_pose,
            projection,
            camera.H,
            camera.W,
            ssaa=1,
            bg_color=bg_color,
            **kwargs,
        )

class DiffMesh:
    """Optimise a mesh's texture (and optionally geometry) against reference views.

    Usage as shown by this class: construct it, call :meth:`prepare_training`
    with the reference images, masks and camera poses, call :meth:`training`,
    then read the result with :meth:`get_mesh_and_texture`.
    """

    def __init__(
        self,
        mesh,
        training_iterations,
        batch_size,
        texture_learning_rate,
        train_mesh_geometry,
        geometry_learning_rate,
        ms_ssim_loss_weight,
        remesh_after_n_iteration,
        invert_bg_prob,
        force_cuda_rasterize
    ):
        """Set up the renderer, optimizer and loss used by :meth:`training`.

        Args:
            mesh: Mesh handed to ``DiffRastRenderer``.
            training_iterations: Number of optimisation steps to run.
            batch_size: Number of randomly chosen reference views per step.
            texture_learning_rate: Forwarded to ``renderer.get_params``.
            train_mesh_geometry: When truthy, geometry regularisers and the
                periodic remesh are enabled; also forwarded to
                ``renderer.get_params``.
            geometry_learning_rate: Forwarded to ``renderer.get_params``.
            ms_ssim_loss_weight: Weight of the MS-SSIM term; the MSE term is
                weighted by ``1 - ms_ssim_loss_weight``.
            remesh_after_n_iteration: Remesh every this many steps. A value
                that is ``None`` or not positive disables remeshing.
            invert_bg_prob: Forwarded to the camera controller.
            force_cuda_rasterize: Forwarded to ``DiffRastRenderer``.
        """
        self.device = torch.device("cuda")

        self.train_mesh_geometry = train_mesh_geometry
        self.remesh_after_n_iteration = remesh_after_n_iteration

        # Kept so the optimizer can be rebuilt after a remesh replaces parameters.
        self.texture_learning_rate = texture_learning_rate
        self.geometry_learning_rate = geometry_learning_rate

        # prepare main components for optimization
        self.renderer = DiffRastRenderer(mesh, force_cuda_rasterize).to(self.device)

        self.optimizer = self._build_optimizer()
        self.ms_ssim_loss = MS_SSIM(data_range=1, size_average=True, channel=3)
        self.lambda_ssim = ms_ssim_loss_weight

        self.training_iterations = training_iterations

        self.batch_size = batch_size

        self.invert_bg_prob = invert_bg_prob

    def _build_optimizer(self):
        """Return a fresh Adam optimizer over the renderer's current parameters."""
        return torch.optim.Adam(
            self.renderer.get_params(
                self.texture_learning_rate,
                self.train_mesh_geometry,
                self.geometry_learning_rate,
            )
        )

    def prepare_training(self, reference_images, reference_masks, reference_orbit_camera_poses, reference_orbit_camera_fovy):
        """Store reference data and build the camera controller.

        Args:
            reference_images: Indexable collection of image tensors. The first
                two dimensions of the first image are used as height and width
                (presumably each image is [H, W, C] -- confirm with caller).
            reference_masks: Indexable collection of mask tensors, one per
                image (presumably [H, W]; a channel dim is appended at index 2).
            reference_orbit_camera_poses: One camera pose per reference image,
                indexed in parallel with ``reference_images``.
            reference_orbit_camera_fovy: Field of view forwarded to the camera
                controller.
        """
        self.ref_imgs_num = len(reference_images)

        self.ref_size_H = reference_images[0].shape[0]
        self.ref_size_W = reference_images[0].shape[1]

        # default camera settings
        # NOTE(review): the meaning of the positional ``None`` is defined by
        # BaseCameraController, which is not visible here.
        self.cam_controller = DiffMeshCameraController(
            self.renderer, self.ref_size_W, self.ref_size_H, reference_orbit_camera_fovy, self.invert_bg_prob, None, self.device
        )

        self.all_ref_cam_poses = reference_orbit_camera_poses

        # prepare reference images and masks
        ref_imgs_torch_list = [
            prepare_torch_img(reference_images[i].unsqueeze(0), self.ref_size_H, self.ref_size_W, self.device)
            for i in range(self.ref_imgs_num)
        ]
        ref_masks_torch_list = [
            prepare_torch_img(reference_masks[i].unsqueeze(2).unsqueeze(0), self.ref_size_H, self.ref_size_W, self.device)
            for i in range(self.ref_imgs_num)
        ]

        self.ref_imgs_torch = torch.cat(ref_imgs_torch_list, dim=0)
        self.ref_masks_torch = torch.cat(ref_masks_torch_list, dim=0)

    def _render_batch(self, ref_imgs_masked):
        """Render ``batch_size`` random reference views and pair them with targets.

        Returns:
            Tuple ``(rendered, reference)`` of masked image batches concatenated
            along dim 0.
        """
        rendered = []
        references = []
        last_index = self.ref_imgs_num - 1
        for _ in range(self.batch_size):
            i = random.randint(0, last_index)

            out = self.cam_controller.render_at_pose(self.all_ref_cam_poses[i])

            image = out["image"]    # [H, W, 3] in [0, 1]
            image = image.permute(2, 0, 1).contiguous()  # [3, H, W] in [0, 1]

            rendered.append((image * self.ref_masks_torch[i]).unsqueeze(0))
            references.append(ref_imgs_masked[i])

        return torch.cat(rendered, dim=0), torch.cat(references, dim=0)

    def _remesh_due(self, step):
        """Return True when the periodic remesh should run after ``step``."""
        interval = self.remesh_after_n_iteration
        if not self.train_mesh_geometry or not interval or interval <= 0:
            # A missing, zero or negative interval disables remeshing rather
            # than raising ZeroDivisionError in the modulo below.
            return False
        return step > 0 and step % interval == 0

    def _remesh(self, decimate_target):
        """Bake offsets into the mesh, clean/decimate it, and reset the optimizer."""
        renderer = self.renderer
        vertices = (renderer.mesh.v + renderer.v_offsets).detach().cpu().numpy()
        triangles = renderer.mesh.f.detach().cpu().numpy()
        vertices, triangles = clean_mesh(vertices, triangles, remesh=True, remesh_size=0.01)
        if triangles.shape[0] > decimate_target:
            vertices, triangles = decimate_mesh(vertices, triangles, decimate_target, optimalplacement=False)
        renderer.mesh.v = torch.from_numpy(vertices).contiguous().float().to(self.device)
        renderer.mesh.f = torch.from_numpy(triangles).contiguous().int().to(self.device)
        # zeros_like inherits the device of mesh.v, so no extra .to() is needed;
        # calling .to() on a Parameter could return a plain non-leaf tensor.
        renderer.v_offsets = nn.Parameter(torch.zeros_like(renderer.mesh.v))
        # The old optimizer still references the previous v_offsets tensor, so
        # without a rebuild the new offsets would never be updated. Rebuilding
        # also resets Adam's moment estimates for all parameters.
        # NOTE(review): assumes renderer.get_params reads the current
        # renderer.v_offsets -- confirm in DiffRastRenderer.
        self.optimizer = self._build_optimizer()

    def training(self, decimate_target=5e4):
        """Run the optimisation loop set up by :meth:`prepare_training`.

        Each step renders a random batch of reference views, combines an MSE
        loss and an MS-SSIM loss on the masked images, adds geometry
        regularisers when geometry training is enabled, and takes one Adam
        step. The mesh is remeshed periodically after the optimizer step.

        Args:
            decimate_target: Face count above which a remeshed mesh is
                decimated down to this target.
        """
        starter = torch.cuda.Event(enable_timing=True)
        ender = torch.cuda.Event(enable_timing=True)
        starter.record()

        ref_imgs_masked = [
            (self.ref_imgs_torch[i] * self.ref_masks_torch[i]).unsqueeze(0)
            for i in range(self.ref_imgs_num)
        ]

        comfy_pbar = comfy.utils.ProgressBar(self.training_iterations)

        last_step = None
        for step in tqdm.trange(self.training_iterations):

            ### calculate loss between reference and rendered image from known view
            rendered_batch, reference_batch = self._render_batch(ref_imgs_masked)

            # rgb loss
            loss = (1 - self.lambda_ssim) * F.mse_loss(rendered_batch, reference_batch)

            # D-SSIM loss
            # [B, 3, H, W] in [0, 1]
            loss = loss + self.lambda_ssim * (1 - self.ms_ssim_loss(reference_batch, rendered_batch))

            # Regularization loss
            if self.train_mesh_geometry:
                current_v = self.renderer.mesh.v + self.renderer.v_offsets
                loss = loss + 0.01 * laplacian_smooth_loss(current_v, self.renderer.mesh.f)
                loss = loss + 0.001 * normal_consistency(current_v, self.renderer.mesh.f)
                loss = loss + 0.1 * (self.renderer.v_offsets ** 2).sum(-1).mean()

            # optimize step
            loss.backward()
            self.optimizer.step()
            self.optimizer.zero_grad()

            # remesh periodically, after the update so this step's geometry
            # gradient is applied before the offsets are baked into the mesh
            if self._remesh_due(step):
                self._remesh(decimate_target)

            comfy_pbar.update_absolute(step + 1)
            last_step = step

        torch.cuda.synchronize()

        self.need_update = True

        # Nothing to report when zero iterations were requested.
        if last_step is not None:
            print(f"Step: {last_step}")

        self.renderer.update_mesh()

        # The elapsed time between the two events is currently not read.
        ender.record()

    def get_mesh_and_texture(self):
        """Return ``(mesh, albedo)`` taken from the renderer's mesh."""
        return (self.renderer.mesh, self.renderer.mesh.albedo, )