yejunliang23 commited on
Commit
1768e4a
·
verified ·
1 Parent(s): 41651f3

Create flexicube.py

Browse files
trellis/representations/mesh/flexicube.py ADDED
@@ -0,0 +1,359 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ #
3
+ # NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
4
+ # and proprietary rights in and to this software, related documentation
5
+ # and any modifications thereto. Any use, reproduction, disclosure or
6
+ # distribution of this software and related documentation without an express
7
+ # license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
8
+
9
+ import torch
10
+ from .tables import *
11
+
12
+ __all__ = [
13
+ 'FlexiCubes'
14
+ ]
15
+
16
+
17
+ class FlexiCubes:
18
+ def __init__(self, device="cuda"):
19
+
20
+ self.device = device
21
+ self.dmc_table = torch.tensor(dmc_table, dtype=torch.long, device=device, requires_grad=False)
22
+ self.num_vd_table = torch.tensor(num_vd_table,
23
+ dtype=torch.long, device=device, requires_grad=False)
24
+ self.check_table = torch.tensor(
25
+ check_table,
26
+ dtype=torch.long, device=device, requires_grad=False)
27
+
28
+ self.tet_table = torch.tensor(tet_table, dtype=torch.long, device=device, requires_grad=False)
29
+ self.quad_split_1 = torch.tensor([0, 1, 2, 0, 2, 3], dtype=torch.long, device=device, requires_grad=False)
30
+ self.quad_split_2 = torch.tensor([0, 1, 3, 3, 1, 2], dtype=torch.long, device=device, requires_grad=False)
31
+ self.quad_split_train = torch.tensor(
32
+ [0, 1, 1, 2, 2, 3, 3, 0], dtype=torch.long, device=device, requires_grad=False)
33
+
34
+ self.cube_corners = torch.tensor([[0, 0, 0], [1, 0, 0], [0, 1, 0], [1, 1, 0], [0, 0, 1], [
35
+ 1, 0, 1], [0, 1, 1], [1, 1, 1]], dtype=torch.float, device=device)
36
+ self.cube_corners_idx = torch.pow(2, torch.arange(8, requires_grad=False))
37
+ self.cube_edges = torch.tensor([0, 1, 1, 5, 4, 5, 0, 4, 2, 3, 3, 7, 6, 7, 2, 6,
38
+ 2, 0, 3, 1, 7, 5, 6, 4], dtype=torch.long, device=device, requires_grad=False)
39
+
40
+ self.edge_dir_table = torch.tensor([0, 2, 0, 2, 0, 2, 0, 2, 1, 1, 1, 1],
41
+ dtype=torch.long, device=device)
42
+ self.dir_faces_table = torch.tensor([
43
+ [[5, 4], [3, 2], [4, 5], [2, 3]],
44
+ [[5, 4], [1, 0], [4, 5], [0, 1]],
45
+ [[3, 2], [1, 0], [2, 3], [0, 1]]
46
+ ], dtype=torch.long, device=device)
47
+ self.adj_pairs = torch.tensor([0, 1, 1, 3, 3, 2, 2, 0], dtype=torch.long, device=device)
48
+
49
+ def __call__(self, voxelgrid_vertices, scalar_field, cube_idx, resolution, qef_reg_scale=1e-3,
50
+ weight_scale=0.99, beta=None, alpha=None, gamma_f=None, voxelgrid_colors=None, training=False):
51
+ surf_cubes, occ_fx8 = self._identify_surf_cubes(scalar_field, cube_idx)
52
+ if surf_cubes.sum() == 0:
53
+ return (
54
+ torch.zeros((0, 3), device=self.device),
55
+ torch.zeros((0, 3), dtype=torch.long, device=self.device),
56
+ torch.zeros((0), device=self.device),
57
+ torch.zeros((0, voxelgrid_colors.shape[-1]), device=self.device) if voxelgrid_colors is not None else None
58
+ )
59
+ beta, alpha, gamma_f = self._normalize_weights(
60
+ beta, alpha, gamma_f, surf_cubes, weight_scale)
61
+
62
+ if voxelgrid_colors is not None:
63
+ voxelgrid_colors = torch.sigmoid(voxelgrid_colors)
64
+
65
+ case_ids = self._get_case_id(occ_fx8, surf_cubes, resolution)
66
+
67
+ surf_edges, idx_map, edge_counts, surf_edges_mask = self._identify_surf_edges(
68
+ scalar_field, cube_idx, surf_cubes
69
+ )
70
+
71
+ vd, L_dev, vd_gamma, vd_idx_map, vd_color = self._compute_vd(
72
+ voxelgrid_vertices, cube_idx[surf_cubes], surf_edges, scalar_field,
73
+ case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors)
74
+ vertices, faces, s_edges, edge_indices, vertices_color = self._triangulate(
75
+ scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map,
76
+ vd_idx_map, surf_edges_mask, training, vd_color)
77
+ return vertices, faces, L_dev, vertices_color
78
+
79
+ def _compute_reg_loss(self, vd, ue, edge_group_to_vd, vd_num_edges):
80
+ """
81
+ Regularizer L_dev as in Equation 8
82
+ """
83
+ dist = torch.norm(ue - torch.index_select(input=vd, index=edge_group_to_vd, dim=0), dim=-1)
84
+ mean_l2 = torch.zeros_like(vd[:, 0])
85
+ mean_l2 = (mean_l2).index_add_(0, edge_group_to_vd, dist) / vd_num_edges.squeeze(1).float()
86
+ mad = (dist - torch.index_select(input=mean_l2, index=edge_group_to_vd, dim=0)).abs()
87
+ return mad
88
+
89
+ def _normalize_weights(self, beta, alpha, gamma_f, surf_cubes, weight_scale):
90
+ """
91
+ Normalizes the given weights to be non-negative. If input weights are None, it creates and returns a set of weights of ones.
92
+ """
93
+ n_cubes = surf_cubes.shape[0]
94
+
95
+ if beta is not None:
96
+ beta = (torch.tanh(beta) * weight_scale + 1)
97
+ else:
98
+ beta = torch.ones((n_cubes, 12), dtype=torch.float, device=self.device)
99
+
100
+ if alpha is not None:
101
+ alpha = (torch.tanh(alpha) * weight_scale + 1)
102
+ else:
103
+ alpha = torch.ones((n_cubes, 8), dtype=torch.float, device=self.device)
104
+
105
+ if gamma_f is not None:
106
+ gamma_f = torch.sigmoid(gamma_f) * weight_scale + (1 - weight_scale) / 2
107
+ else:
108
+ gamma_f = torch.ones((n_cubes), dtype=torch.float, device=self.device)
109
+
110
+ return beta[surf_cubes], alpha[surf_cubes], gamma_f[surf_cubes]
111
+
112
+ @torch.no_grad()
113
+ def _get_case_id(self, occ_fx8, surf_cubes, res):
114
+ """
115
+ Obtains the ID of topology cases based on cell corner occupancy. This function resolves the
116
+ ambiguity in the Dual Marching Cubes (DMC) configurations as described in Section 1.3 of the
117
+ supplementary material. It should be noted that this function assumes a regular grid.
118
+ """
119
+ case_ids = (occ_fx8[surf_cubes] * self.cube_corners_idx.to(self.device).unsqueeze(0)).sum(-1)
120
+
121
+ problem_config = self.check_table.to(self.device)[case_ids]
122
+ to_check = problem_config[..., 0] == 1
123
+ problem_config = problem_config[to_check]
124
+ if not isinstance(res, (list, tuple)):
125
+ res = [res, res, res]
126
+
127
+ # The 'problematic_configs' only contain configurations for surface cubes. Next, we construct a 3D array,
128
+ # 'problem_config_full', to store configurations for all cubes (with default config for non-surface cubes).
129
+ # This allows efficient checking on adjacent cubes.
130
+ problem_config_full = torch.zeros(list(res) + [5], device=self.device, dtype=torch.long)
131
+ vol_idx = torch.nonzero(problem_config_full[..., 0] == 0) # N, 3
132
+ vol_idx_problem = vol_idx[surf_cubes][to_check]
133
+ problem_config_full[vol_idx_problem[..., 0], vol_idx_problem[..., 1], vol_idx_problem[..., 2]] = problem_config
134
+ vol_idx_problem_adj = vol_idx_problem + problem_config[..., 1:4]
135
+
136
+ within_range = (
137
+ vol_idx_problem_adj[..., 0] >= 0) & (
138
+ vol_idx_problem_adj[..., 0] < res[0]) & (
139
+ vol_idx_problem_adj[..., 1] >= 0) & (
140
+ vol_idx_problem_adj[..., 1] < res[1]) & (
141
+ vol_idx_problem_adj[..., 2] >= 0) & (
142
+ vol_idx_problem_adj[..., 2] < res[2])
143
+
144
+ vol_idx_problem = vol_idx_problem[within_range]
145
+ vol_idx_problem_adj = vol_idx_problem_adj[within_range]
146
+ problem_config = problem_config[within_range]
147
+ problem_config_adj = problem_config_full[vol_idx_problem_adj[..., 0],
148
+ vol_idx_problem_adj[..., 1], vol_idx_problem_adj[..., 2]]
149
+ # If two cubes with cases C16 and C19 share an ambiguous face, both cases are inverted.
150
+ to_invert = (problem_config_adj[..., 0] == 1)
151
+ idx = torch.arange(case_ids.shape[0], device=self.device)[to_check][within_range][to_invert]
152
+ case_ids.index_put_((idx,), problem_config[to_invert][..., -1])
153
+ return case_ids
154
+
155
+ @torch.no_grad()
156
+ def _identify_surf_edges(self, scalar_field, cube_idx, surf_cubes):
157
+ """
158
+ Identifies grid edges that intersect with the underlying surface by checking for opposite signs. As each edge
159
+ can be shared by multiple cubes, this function also assigns a unique index to each surface-intersecting edge
160
+ and marks the cube edges with this index.
161
+ """
162
+ occ_n = scalar_field < 0
163
+ all_edges = cube_idx[surf_cubes][:, self.cube_edges].reshape(-1, 2)
164
+ unique_edges, _idx_map, counts = torch.unique(all_edges, dim=0, return_inverse=True, return_counts=True)
165
+
166
+ unique_edges = unique_edges.long()
167
+ mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
168
+
169
+ surf_edges_mask = mask_edges[_idx_map]
170
+ counts = counts[_idx_map]
171
+
172
+ mapping = torch.ones((unique_edges.shape[0]), dtype=torch.long, device=cube_idx.device) * -1
173
+ mapping[mask_edges] = torch.arange(mask_edges.sum(), device=cube_idx.device)
174
+ # Shaped as [number of cubes x 12 edges per cube]. This is later used to map a cube edge to the unique index
175
+ # for a surface-intersecting edge. Non-surface-intersecting edges are marked with -1.
176
+ idx_map = mapping[_idx_map]
177
+ surf_edges = unique_edges[mask_edges]
178
+ return surf_edges, idx_map, counts, surf_edges_mask
179
+
180
+ @torch.no_grad()
181
+ def _identify_surf_cubes(self, scalar_field, cube_idx):
182
+ """
183
+ Identifies grid cubes that intersect with the underlying surface by checking if the signs at
184
+ all corners are not identical.
185
+ """
186
+ occ_n = scalar_field < 0
187
+ occ_fx8 = occ_n[cube_idx.reshape(-1)].reshape(-1, 8)
188
+ _occ_sum = torch.sum(occ_fx8, -1)
189
+ surf_cubes = (_occ_sum > 0) & (_occ_sum < 8)
190
+ return surf_cubes, occ_fx8
191
+
192
+ def _linear_interp(self, edges_weight, edges_x):
193
+ """
194
+ Computes the location of zero-crossings on 'edges_x' using linear interpolation with 'edges_weight'.
195
+ """
196
+ edge_dim = edges_weight.dim() - 2
197
+ assert edges_weight.shape[edge_dim] == 2
198
+ edges_weight = torch.cat([torch.index_select(input=edges_weight, index=torch.tensor(1, device=self.device), dim=edge_dim), -
199
+ torch.index_select(input=edges_weight, index=torch.tensor(0, device=self.device), dim=edge_dim)]
200
+ , edge_dim)
201
+ denominator = edges_weight.sum(edge_dim)
202
+ ue = (edges_x * edges_weight).sum(edge_dim) / denominator
203
+ return ue
204
+
205
+ def _solve_vd_QEF(self, p_bxnx3, norm_bxnx3, c_bx3, qef_reg_scale):
206
+ p_bxnx3 = p_bxnx3.reshape(-1, 7, 3)
207
+ norm_bxnx3 = norm_bxnx3.reshape(-1, 7, 3)
208
+ c_bx3 = c_bx3.reshape(-1, 3)
209
+ A = norm_bxnx3
210
+ B = ((p_bxnx3) * norm_bxnx3).sum(-1, keepdims=True)
211
+
212
+ A_reg = (torch.eye(3, device=p_bxnx3.device) * qef_reg_scale).unsqueeze(0).repeat(p_bxnx3.shape[0], 1, 1)
213
+ B_reg = (qef_reg_scale * c_bx3).unsqueeze(-1)
214
+ A = torch.cat([A, A_reg], 1)
215
+ B = torch.cat([B, B_reg], 1)
216
+ dual_verts = torch.linalg.lstsq(A, B).solution.squeeze(-1)
217
+ return dual_verts
218
+
219
+ def _compute_vd(self, voxelgrid_vertices, surf_cubes_fx8, surf_edges, scalar_field,
220
+ case_ids, beta, alpha, gamma_f, idx_map, qef_reg_scale, voxelgrid_colors):
221
+ """
222
+ Computes the location of dual vertices as described in Section 4.2
223
+ """
224
+ alpha_nx12x2 = torch.index_select(input=alpha, index=self.cube_edges, dim=1).reshape(-1, 12, 2)
225
+ surf_edges_x = torch.index_select(input=voxelgrid_vertices, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 3)
226
+ surf_edges_s = torch.index_select(input=scalar_field, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, 1)
227
+ zero_crossing = self._linear_interp(surf_edges_s, surf_edges_x)
228
+
229
+ if voxelgrid_colors is not None:
230
+ C = voxelgrid_colors.shape[-1]
231
+ surf_edges_c = torch.index_select(input=voxelgrid_colors, index=surf_edges.reshape(-1), dim=0).reshape(-1, 2, C)
232
+
233
+ idx_map = idx_map.reshape(-1, 12)
234
+ num_vd = torch.index_select(input=self.num_vd_table, index=case_ids, dim=0)
235
+ edge_group, edge_group_to_vd, edge_group_to_cube, vd_num_edges, vd_gamma = [], [], [], [], []
236
+
237
+ # if color is not None:
238
+ # vd_color = []
239
+
240
+ total_num_vd = 0
241
+ vd_idx_map = torch.zeros((case_ids.shape[0], 12), dtype=torch.long, device=self.device, requires_grad=False)
242
+
243
+ for num in torch.unique(num_vd):
244
+ cur_cubes = (num_vd == num) # consider cubes with the same numbers of vd emitted (for batching)
245
+ curr_num_vd = cur_cubes.sum() * num
246
+ curr_edge_group = self.dmc_table[case_ids[cur_cubes], :num].reshape(-1, num * 7)
247
+ curr_edge_group_to_vd = torch.arange(
248
+ curr_num_vd, device=self.device).unsqueeze(-1).repeat(1, 7) + total_num_vd
249
+ total_num_vd += curr_num_vd
250
+ curr_edge_group_to_cube = torch.arange(idx_map.shape[0], device=self.device)[
251
+ cur_cubes].unsqueeze(-1).repeat(1, num * 7).reshape_as(curr_edge_group)
252
+
253
+ curr_mask = (curr_edge_group != -1)
254
+ edge_group.append(torch.masked_select(curr_edge_group, curr_mask))
255
+ edge_group_to_vd.append(torch.masked_select(curr_edge_group_to_vd.reshape_as(curr_edge_group), curr_mask))
256
+ edge_group_to_cube.append(torch.masked_select(curr_edge_group_to_cube, curr_mask))
257
+ vd_num_edges.append(curr_mask.reshape(-1, 7).sum(-1, keepdims=True))
258
+ vd_gamma.append(torch.masked_select(gamma_f, cur_cubes).unsqueeze(-1).repeat(1, num).reshape(-1))
259
+ # if color is not None:
260
+ # vd_color.append(color[cur_cubes].unsqueeze(1).repeat(1, num, 1).reshape(-1, 3))
261
+
262
+ edge_group = torch.cat(edge_group)
263
+ edge_group_to_vd = torch.cat(edge_group_to_vd)
264
+ edge_group_to_cube = torch.cat(edge_group_to_cube)
265
+ vd_num_edges = torch.cat(vd_num_edges)
266
+ vd_gamma = torch.cat(vd_gamma)
267
+ # if color is not None:
268
+ # vd_color = torch.cat(vd_color)
269
+ # else:
270
+ # vd_color = None
271
+
272
+ vd = torch.zeros((total_num_vd, 3), device=self.device)
273
+ beta_sum = torch.zeros((total_num_vd, 1), device=self.device)
274
+
275
+ idx_group = torch.gather(input=idx_map.reshape(-1), dim=0, index=edge_group_to_cube * 12 + edge_group)
276
+
277
+ x_group = torch.index_select(input=surf_edges_x, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 3)
278
+ s_group = torch.index_select(input=surf_edges_s, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, 1)
279
+
280
+
281
+ zero_crossing_group = torch.index_select(
282
+ input=zero_crossing, index=idx_group.reshape(-1), dim=0).reshape(-1, 3)
283
+
284
+ alpha_group = torch.index_select(input=alpha_nx12x2.reshape(-1, 2), dim=0,
285
+ index=edge_group_to_cube * 12 + edge_group).reshape(-1, 2, 1)
286
+ ue_group = self._linear_interp(s_group * alpha_group, x_group)
287
+
288
+ beta_group = torch.gather(input=beta.reshape(-1), dim=0,
289
+ index=edge_group_to_cube * 12 + edge_group).reshape(-1, 1)
290
+ beta_sum = beta_sum.index_add_(0, index=edge_group_to_vd, source=beta_group)
291
+ vd = vd.index_add_(0, index=edge_group_to_vd, source=ue_group * beta_group) / beta_sum
292
+
293
+ '''
294
+ interpolate colors use the same method as dual vertices
295
+ '''
296
+ if voxelgrid_colors is not None:
297
+ vd_color = torch.zeros((total_num_vd, C), device=self.device)
298
+ c_group = torch.index_select(input=surf_edges_c, index=idx_group.reshape(-1), dim=0).reshape(-1, 2, C)
299
+ uc_group = self._linear_interp(s_group * alpha_group, c_group)
300
+ vd_color = vd_color.index_add_(0, index=edge_group_to_vd, source=uc_group * beta_group) / beta_sum
301
+ else:
302
+ vd_color = None
303
+
304
+ L_dev = self._compute_reg_loss(vd, zero_crossing_group, edge_group_to_vd, vd_num_edges)
305
+
306
+ v_idx = torch.arange(vd.shape[0], device=self.device) # + total_num_vd
307
+
308
+ vd_idx_map = (vd_idx_map.reshape(-1)).scatter(dim=0, index=edge_group_to_cube *
309
+ 12 + edge_group, src=v_idx[edge_group_to_vd])
310
+
311
+ return vd, L_dev, vd_gamma, vd_idx_map, vd_color
312
+
313
+ def _triangulate(self, scalar_field, surf_edges, vd, vd_gamma, edge_counts, idx_map, vd_idx_map, surf_edges_mask, training, vd_color):
314
+ """
315
+ Connects four neighboring dual vertices to form a quadrilateral. The quadrilaterals are then split into
316
+ triangles based on the gamma parameter, as described in Section 4.3.
317
+ """
318
+ with torch.no_grad():
319
+ group_mask = (edge_counts == 4) & surf_edges_mask # surface edges shared by 4 cubes.
320
+ group = idx_map.reshape(-1)[group_mask]
321
+ vd_idx = vd_idx_map[group_mask]
322
+ edge_indices, indices = torch.sort(group, stable=True)
323
+ quad_vd_idx = vd_idx[indices].reshape(-1, 4)
324
+
325
+ # Ensure all face directions point towards the positive SDF to maintain consistent winding.
326
+ s_edges = scalar_field[surf_edges[edge_indices.reshape(-1, 4)[:, 0]].reshape(-1)].reshape(-1, 2)
327
+ flip_mask = s_edges[:, 0] > 0
328
+ quad_vd_idx = torch.cat((quad_vd_idx[flip_mask][:, [0, 1, 3, 2]],
329
+ quad_vd_idx[~flip_mask][:, [2, 3, 1, 0]]))
330
+
331
+ quad_gamma = torch.index_select(input=vd_gamma, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4)
332
+ gamma_02 = quad_gamma[:, 0] * quad_gamma[:, 2]
333
+ gamma_13 = quad_gamma[:, 1] * quad_gamma[:, 3]
334
+ if not training:
335
+ mask = (gamma_02 > gamma_13)
336
+ faces = torch.zeros((quad_gamma.shape[0], 6), dtype=torch.long, device=quad_vd_idx.device)
337
+ faces[mask] = quad_vd_idx[mask][:, self.quad_split_1]
338
+ faces[~mask] = quad_vd_idx[~mask][:, self.quad_split_2]
339
+ faces = faces.reshape(-1, 3)
340
+ else:
341
+ vd_quad = torch.index_select(input=vd, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, 3)
342
+ vd_02 = (vd_quad[:, 0] + vd_quad[:, 2]) / 2
343
+ vd_13 = (vd_quad[:, 1] + vd_quad[:, 3]) / 2
344
+ weight_sum = (gamma_02 + gamma_13) + 1e-8
345
+ vd_center = (vd_02 * gamma_02.unsqueeze(-1) + vd_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
346
+
347
+ if vd_color is not None:
348
+ color_quad = torch.index_select(input=vd_color, index=quad_vd_idx.reshape(-1), dim=0).reshape(-1, 4, vd_color.shape[-1])
349
+ color_02 = (color_quad[:, 0] + color_quad[:, 2]) / 2
350
+ color_13 = (color_quad[:, 1] + color_quad[:, 3]) / 2
351
+ color_center = (color_02 * gamma_02.unsqueeze(-1) + color_13 * gamma_13.unsqueeze(-1)) / weight_sum.unsqueeze(-1)
352
+ vd_color = torch.cat([vd_color, color_center])
353
+
354
+
355
+ vd_center_idx = torch.arange(vd_center.shape[0], device=self.device) + vd.shape[0]
356
+ vd = torch.cat([vd, vd_center])
357
+ faces = quad_vd_idx[:, self.quad_split_train].reshape(-1, 4, 2)
358
+ faces = torch.cat([faces, vd_center_idx.reshape(-1, 1, 1).repeat(1, 4, 1)], -1).reshape(-1, 3)
359
+ return vd, faces, s_edges, edge_indices, vd_color