ZeqiangLai committed on
Commit 6b88e7a · verified · 1 Parent(s): f2802bb

Upload 16 files

hy3dgen/shapegen/__init__.py CHANGED
@@ -23,5 +23,5 @@
  # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
 
  from .pipelines import Hunyuan3DDiTPipeline, Hunyuan3DDiTFlowMatchingPipeline
- from .postprocessors import FaceReducer, FloaterRemover, DegenerateFaceRemover, MeshSimplifier
+ from .postprocessors import FaceReducer, FloaterRemover, DegenerateFaceRemover
  from .preprocessors import ImageProcessorV2, IMAGE_PROCESSORS, DEFAULT_IMAGEPROCESSOR
hy3dgen/shapegen/models/__init__.py CHANGED
@@ -25,4 +25,4 @@
 
  from .autoencoders import ShapeVAE
  from .conditioner import DualImageEncoder, SingleImageEncoder, DinoImageEncoder, CLIPImageEncoder
- from .denoisers import Hunyuan3DDiT
+ from .denoisers import HunYuanDiTPlain, Hunyuan3DDiT
hy3dgen/shapegen/models/autoencoders/__init__.py CHANGED
@@ -1,5 +1,5 @@
  from .attention_blocks import CrossAttentionDecoder
- from .attention_processors import CrossAttentionProcessor
+ from .attention_processors import FlashVDMCrossAttentionProcessor, CrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
  from .model import ShapeVAE, VectsetVAE
- from .surface_extractors import SurfaceExtractors, MCSurfaceExtractor, DMCSurfaceExtractor, Latent2MeshOutput
- from .volume_decoders import VanillaVolumeDecoder
+ from .surface_extractors import SurfaceExtractors, MCSurfaceExtractor, DMCSurfaceExtractor
+ from .volume_decoders import HierarchicalVolumeDecoding, FlashVDMVolumeDecoding, VanillaVolumeDecoder
hy3dgen/shapegen/models/autoencoders/attention_processors.py CHANGED
@@ -17,3 +17,95 @@ class CrossAttentionProcessor:
          out = scaled_dot_product_attention(q, k, v)
          return out
 
+
+ class FlashVDMCrossAttentionProcessor:
+     def __init__(self, topk=None):
+         self.topk = topk
+
+     def __call__(self, attn, q, k, v):
+         if k.shape[-2] == 3072:
+             topk = 1024
+         elif k.shape[-2] == 512:
+             topk = 256
+         else:
+             topk = k.shape[-2] // 3
+
+         if self.topk is True:
+             q1 = q[:, :, ::100, :]
+             sim = q1 @ k.transpose(-1, -2)
+             sim = torch.mean(sim, -2)
+             topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
+             topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
+             v0 = torch.gather(v, dim=-2, index=topk_ind)
+             k0 = torch.gather(k, dim=-2, index=topk_ind)
+             out = scaled_dot_product_attention(q, k0, v0)
+         elif self.topk is False:
+             out = scaled_dot_product_attention(q, k, v)
+         else:
+             idx, counts = self.topk
+             start = 0
+             outs = []
+             for grid_coord, count in zip(idx, counts):
+                 end = start + count
+                 q_chunk = q[:, :, start:end, :]
+                 q1 = q_chunk[:, :, ::50, :]
+                 sim = q1 @ k.transpose(-1, -2)
+                 sim = torch.mean(sim, -2)
+                 topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
+                 topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
+                 v0 = torch.gather(v, dim=-2, index=topk_ind)
+                 k0 = torch.gather(k, dim=-2, index=topk_ind)
+                 out = scaled_dot_product_attention(q_chunk, k0, v0)
+                 outs.append(out)
+                 start += count
+             out = torch.cat(outs, dim=-2)
+         self.topk = False
+         return out
+
+
+ class FlashVDMTopMCrossAttentionProcessor:
+     def __init__(self, topk=None):
+         self.topk = topk
+
+     def __call__(self, attn, q, k, v):
+         if k.shape[-2] == 3072:
+             topk = 1024
+         elif k.shape[-2] == 512:
+             topk = 256
+         else:
+             topk = k.shape[-2] // 3
+
+         if self.topk is True:
+             q1 = q[:, :, ::100, :]
+             sim = q1 @ k.transpose(-1, -2)
+             sim = torch.mean(sim, -2)
+             topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
+             topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
+             v0 = torch.gather(v, dim=-2, index=topk_ind)
+             k0 = torch.gather(k, dim=-2, index=topk_ind)
+             out = scaled_dot_product_attention(q, k0, v0)
+         elif self.topk is False:
+             out = scaled_dot_product_attention(q, k, v)
+         else:
+             idx, counts = self.topk
+             start = 0
+             outs = []
+             for grid_coord, count in zip(idx, counts):
+                 end = start + count
+                 q_chunk = q[:, :, start:end, :]
+                 q1 = q_chunk[:, :, ::30, :]
+                 sim = q1 @ k.transpose(-1, -2)
+                 # sim = sim.to(torch.float32)
+                 sim = sim.softmax(-1)
+                 sim = torch.mean(sim, 1)
+                 activated_token = torch.where(sim > 1e-6)[2]
+                 index = torch.unique(activated_token, return_counts=True)[0].unsqueeze(0).unsqueeze(0).unsqueeze(-1)
+                 index = index.expand(-1, v.shape[1], -1, v.shape[-1])
+                 v0 = torch.gather(v, dim=-2, index=index)
+                 k0 = torch.gather(k, dim=-2, index=index)
+                 out = scaled_dot_product_attention(q_chunk, k0, v0)  # bhnc
+                 outs.append(out)
+                 start += count
+             out = torch.cat(outs, dim=-2)
+         self.topk = False
+         return out
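For orientation, here is a small self-contained sketch (not part of the commit) of the probe-then-top-k pattern both processors above implement: score the keys with a strided subset of the queries, keep only the top-k key/value tokens, then run scaled dot-product attention. The shapes, stride, and top-k value below are illustrative only.

    import torch
    from torch.nn.functional import scaled_dot_product_attention

    def topk_sdpa(q, k, v, topk=256, stride=100):
        # Probe a strided subset of queries to score which KV tokens matter...
        q_probe = q[:, :, ::stride, :]                      # [B, H, Nq/stride, C]
        sim = (q_probe @ k.transpose(-1, -2)).mean(-2)      # [B, H, Nk]
        ind = sim.topk(topk, dim=-1).indices.unsqueeze(-1)  # [B, H, topk, 1]
        ind = ind.expand(-1, -1, -1, v.shape[-1])
        k0 = torch.gather(k, -2, ind)
        v0 = torch.gather(v, -2, ind)
        # ...then attend all queries against only the selected keys/values.
        return scaled_dot_product_attention(q, k0, v0)

    q = torch.randn(1, 8, 4096, 64)
    k = torch.randn(1, 8, 3072, 64)
    v = torch.randn(1, 8, 3072, 64)
    print(topk_sdpa(q, k, v).shape)  # torch.Size([1, 8, 4096, 64])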
hy3dgen/shapegen/models/autoencoders/model.py CHANGED
@@ -6,7 +6,7 @@ import yaml
 
  from .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder
  from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors
- from .volume_decoders import VanillaVolumeDecoder
+ from .volume_decoders import VanillaVolumeDecoder, FlashVDMVolumeDecoding, HierarchicalVolumeDecoding
  from ...utils import logger, synchronize_timer
 
 
@@ -117,6 +117,25 @@ class VectsetVAE(nn.Module):
          outputs = self.surface_extractor(grid_logits, **kwargs)
          return outputs
 
+     def enable_flashvdm_decoder(
+         self,
+         enabled: bool = True,
+         adaptive_kv_selection=True,
+         topk_mode='mean',
+         mc_algo='dmc',
+     ):
+         if enabled:
+             if adaptive_kv_selection:
+                 self.volume_decoder = FlashVDMVolumeDecoding(topk_mode)
+             else:
+                 self.volume_decoder = HierarchicalVolumeDecoding()
+             if mc_algo not in SurfaceExtractors.keys():
+                 raise ValueError(f'Unsupported mc_algo {mc_algo}, available: {list(SurfaceExtractors.keys())}')
+             self.surface_extractor = SurfaceExtractors[mc_algo]()
+         else:
+             self.volume_decoder = VanillaVolumeDecoder()
+             self.surface_extractor = MCSurfaceExtractor()
+
 
  class ShapeVAE(VectsetVAE):
      def __init__(
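A minimal usage sketch of the decoder switch added above. The checkpoint id and subfolder are assumptions taken from the VAE mappings elsewhere in this commit; the keyword arguments mirror the enable_flashvdm_decoder signature in the hunk.

    from hy3dgen.shapegen.models.autoencoders import ShapeVAE

    # Assumed repo id / subfolder; substitute the VAE checkpoint you actually use.
    vae = ShapeVAE.from_pretrained('tencent/Hunyuan3D-2', subfolder='hunyuan3d-vae-v2-0-turbo')

    # Adaptive top-k KV selection ('mean' strategy) plus DMC surface extraction.
    vae.enable_flashvdm_decoder(enabled=True, adaptive_kv_selection=True, topk_mode='mean', mc_algo='dmc')

    # Disabling restores VanillaVolumeDecoder + MCSurfaceExtractor.
    vae.enable_flashvdm_decoder(enabled=False)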
hy3dgen/shapegen/models/autoencoders/volume_decoders.py CHANGED
@@ -8,9 +8,111 @@ from einops import repeat
  from tqdm import tqdm
 
  from .attention_blocks import CrossAttentionDecoder
+ from .attention_processors import FlashVDMCrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
  from ...utils import logger
 
 
+ def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
+     """
+     PyTorch implementation with the dimension issue fixed
+     Args:
+         input_tensor: shape [D, D, D], torch.float16
+         alpha: scalar offset value
+     Returns:
+         mask: shape [D, D, D], torch.int32 surface mask
+     """
+     device = input_tensor.device
+     D = input_tensor.shape[0]
+     signed_val = 0.0
+
+     # Add the offset and handle invalid values
+     val = input_tensor + alpha
+     valid_mask = val > -9000  # assume -9000 marks invalid values
+
+     # Improved neighbour lookup that keeps dimensions consistent
+     def get_neighbor(t, shift, axis):
+         """Shift along the given axis while keeping the output shape unchanged"""
+         if shift == 0:
+             return t.clone()
+
+         # Determine the padding axes (the [D, D, D] input corresponds to the z, y, x axes)
+         pad_dims = [0, 0, 0, 0, 0, 0]  # format: [x_before, x_after, y_before, y_after, z_before, z_after]
+
+         # Set the padding according to the axis
+         if axis == 0:  # x axis (last dimension)
+             pad_idx = 0 if shift > 0 else 1
+             pad_dims[pad_idx] = abs(shift)
+         elif axis == 1:  # y axis (middle dimension)
+             pad_idx = 2 if shift > 0 else 3
+             pad_dims[pad_idx] = abs(shift)
+         elif axis == 2:  # z axis (first dimension)
+             pad_idx = 4 if shift > 0 else 5
+             pad_dims[pad_idx] = abs(shift)
+
+         # Apply the padding (add batch and channel dims so F.pad accepts the tensor)
+         padded = F.pad(t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate')  # reversed order for F.pad
+
+         # Build the dynamic slicing indices
+         slice_dims = [slice(None)] * 3  # start with full slices
+         if axis == 0:  # x axis (dim=2)
+             if shift > 0:
+                 slice_dims[0] = slice(shift, None)
+             else:
+                 slice_dims[0] = slice(None, shift)
+         elif axis == 1:  # y axis (dim=1)
+             if shift > 0:
+                 slice_dims[1] = slice(shift, None)
+             else:
+                 slice_dims[1] = slice(None, shift)
+         elif axis == 2:  # z axis (dim=0)
+             if shift > 0:
+                 slice_dims[2] = slice(shift, None)
+             else:
+                 slice_dims[2] = slice(None, shift)
+
+         # Apply the slices and restore the original dimensions
+         padded = padded.squeeze(0).squeeze(0)
+         sliced = padded[slice_dims]
+         return sliced
+
+     # Fetch the neighbours in every direction (keeping dimensions consistent)
+     left = get_neighbor(val, 1, axis=0)  # x direction
+     right = get_neighbor(val, -1, axis=0)
+     back = get_neighbor(val, 1, axis=1)  # y direction
+     front = get_neighbor(val, -1, axis=1)
+     down = get_neighbor(val, 1, axis=2)  # z direction
+     up = get_neighbor(val, -1, axis=2)
+
+     # Handle invalid boundary values (use where to keep dimensions consistent)
+     def safe_where(neighbor):
+         return torch.where(neighbor > -9000, neighbor, val)
+
+     left = safe_where(left)
+     right = safe_where(right)
+     back = safe_where(back)
+     front = safe_where(front)
+     down = safe_where(down)
+     up = safe_where(up)
+
+     # Compute sign consistency (cast to float32 for precision)
+     sign = torch.sign(val.to(torch.float32))
+     neighbors_sign = torch.stack([
+         torch.sign(left.to(torch.float32)),
+         torch.sign(right.to(torch.float32)),
+         torch.sign(back.to(torch.float32)),
+         torch.sign(front.to(torch.float32)),
+         torch.sign(down.to(torch.float32)),
+         torch.sign(up.to(torch.float32))
+     ], dim=0)
+
+     # Check whether all neighbour signs agree with the centre
+     same_sign = torch.all(neighbors_sign == sign, dim=0)
+
+     # Build the final mask
+     mask = (~same_sign).to(torch.int32)
+     return mask * valid_mask.to(torch.int32)
+
+
  def generate_dense_grid_points(
      bbox_min: np.ndarray,
      bbox_max: np.ndarray,
@@ -74,3 +176,254 @@ class VanillaVolumeDecoder:
          return grid_logits
 
 
+ class HierarchicalVolumeDecoding:
+     @torch.no_grad()
+     def __call__(
+         self,
+         latents: torch.FloatTensor,
+         geo_decoder: Callable,
+         bounds: Union[Tuple[float], List[float], float] = 1.01,
+         num_chunks: int = 10000,
+         mc_level: float = 0.0,
+         octree_resolution: int = None,
+         min_resolution: int = 63,
+         enable_pbar: bool = True,
+         **kwargs,
+     ):
+         device = latents.device
+         dtype = latents.dtype
+
+         resolutions = []
+         if octree_resolution < min_resolution:
+             resolutions.append(octree_resolution)
+         while octree_resolution >= min_resolution:
+             resolutions.append(octree_resolution)
+             octree_resolution = octree_resolution // 2
+         resolutions.reverse()
+
+         # 1. generate query points
+         if isinstance(bounds, float):
+             bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
+         bbox_min = np.array(bounds[0:3])
+         bbox_max = np.array(bounds[3:6])
+         bbox_size = bbox_max - bbox_min
+
+         xyz_samples, grid_size, length = generate_dense_grid_points(
+             bbox_min=bbox_min,
+             bbox_max=bbox_max,
+             octree_resolution=resolutions[0],
+             indexing="ij"
+         )
+
+         dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
+         dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
+
+         grid_size = np.array(grid_size)
+         xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
+
+         # 2. latents to 3d volume
+         batch_logits = []
+         batch_size = latents.shape[0]
+         for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
+                           desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]"):
+             queries = xyz_samples[start: start + num_chunks, :]
+             batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
+             logits = geo_decoder(queries=batch_queries, latents=latents)
+             batch_logits.append(logits)
+
+         grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2]))
+
+         for octree_depth_now in resolutions[1:]:
+             grid_size = np.array([octree_depth_now + 1] * 3)
+             resolution = bbox_size / octree_depth_now
+             next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
+             next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
+             curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
+             curr_points += grid_logits.squeeze(0).abs() < 0.95
+
+             if octree_depth_now == resolutions[-1]:
+                 expand_num = 0
+             else:
+                 expand_num = 1
+             for i in range(expand_num):
+                 curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
+             (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
+             next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
+             for i in range(2 - expand_num):
+                 next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
+             nidx = torch.where(next_index > 0)
+
+             next_points = torch.stack(nidx, dim=1)
+             next_points = (next_points * torch.tensor(resolution, dtype=next_points.dtype, device=device) +
+                            torch.tensor(bbox_min, dtype=next_points.dtype, device=device))
+             batch_logits = []
+             for start in tqdm(range(0, next_points.shape[0], num_chunks),
+                               desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]"):
+                 queries = next_points[start: start + num_chunks, :]
+                 batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
+                 logits = geo_decoder(queries=batch_queries.to(latents.dtype), latents=latents)
+                 batch_logits.append(logits)
+             grid_logits = torch.cat(batch_logits, dim=1)
+             next_logits[nidx] = grid_logits[0, ..., 0]
+             grid_logits = next_logits.unsqueeze(0)
+         grid_logits[grid_logits == -10000.] = float('nan')
+
+         return grid_logits
+
+
+ class FlashVDMVolumeDecoding:
+     def __init__(self, topk_mode='mean'):
+         if topk_mode not in ['mean', 'merge']:
+             raise ValueError(f'Unsupported topk_mode {topk_mode}, available: {["mean", "merge"]}')
+
+         if topk_mode == 'mean':
+             self.processor = FlashVDMCrossAttentionProcessor()
+         else:
+             self.processor = FlashVDMTopMCrossAttentionProcessor()
+
+     @torch.no_grad()
+     def __call__(
+         self,
+         latents: torch.FloatTensor,
+         geo_decoder: CrossAttentionDecoder,
+         bounds: Union[Tuple[float], List[float], float] = 1.01,
+         num_chunks: int = 10000,
+         mc_level: float = 0.0,
+         octree_resolution: int = None,
+         min_resolution: int = 63,
+         mini_grid_num: int = 4,
+         enable_pbar: bool = True,
+         **kwargs,
+     ):
+         processor = self.processor
+         geo_decoder.set_cross_attention_processor(processor)
+
+         device = latents.device
+         dtype = latents.dtype
+
+         resolutions = []
+         if octree_resolution < min_resolution:
+             resolutions.append(octree_resolution)
+         while octree_resolution >= min_resolution:
+             resolutions.append(octree_resolution)
+             octree_resolution = octree_resolution // 2
+         resolutions.reverse()
+         resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
+         for i, resolution in enumerate(resolutions[1:]):
+             resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)
+
+         logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}")
+
+         # 1. generate query points
+         if isinstance(bounds, float):
+             bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
+         bbox_min = np.array(bounds[0:3])
+         bbox_max = np.array(bounds[3:6])
+         bbox_size = bbox_max - bbox_min
+
+         xyz_samples, grid_size, length = generate_dense_grid_points(
+             bbox_min=bbox_min,
+             bbox_max=bbox_max,
+             octree_resolution=resolutions[0],
+             indexing="ij"
+         )
+
+         dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
+         dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
+
+         grid_size = np.array(grid_size)
+
+         # 2. latents to 3d volume
+         xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
+         batch_size = latents.shape[0]
+         mini_grid_size = xyz_samples.shape[0] // mini_grid_num
+         xyz_samples = xyz_samples.view(
+             mini_grid_num, mini_grid_size,
+             mini_grid_num, mini_grid_size,
+             mini_grid_num, mini_grid_size, 3
+         ).permute(
+             0, 2, 4, 1, 3, 5, 6
+         ).reshape(
+             -1, mini_grid_size * mini_grid_size * mini_grid_size, 3
+         )
+         batch_logits = []
+         num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
+         for start in tqdm(range(0, xyz_samples.shape[0], num_batchs),
+                           desc=f"FlashVDM Volume Decoding", disable=not enable_pbar):
+             queries = xyz_samples[start: start + num_batchs, :]
+             batch = queries.shape[0]
+             batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
+             processor.topk = True
+             logits = geo_decoder(queries=queries, latents=batch_latents)
+             batch_logits.append(logits)
+         grid_logits = torch.cat(batch_logits, dim=0).reshape(
+             mini_grid_num, mini_grid_num, mini_grid_num,
+             mini_grid_size, mini_grid_size,
+             mini_grid_size
+         ).permute(0, 3, 1, 4, 2, 5).contiguous().view(
+             (batch_size, grid_size[0], grid_size[1], grid_size[2])
+         )
+
+         for octree_depth_now in resolutions[1:]:
+             grid_size = np.array([octree_depth_now + 1] * 3)
+             resolution = bbox_size / octree_depth_now
+             next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
+             next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
+             curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
+             curr_points += grid_logits.squeeze(0).abs() < 0.95
+
+             if octree_depth_now == resolutions[-1]:
+                 expand_num = 0
+             else:
+                 expand_num = 1
+             for i in range(expand_num):
+                 curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
+             (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
+
+             next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
+             for i in range(2 - expand_num):
+                 next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
+             nidx = torch.where(next_index > 0)
+
+             next_points = torch.stack(nidx, dim=1)
+             next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
+                            torch.tensor(bbox_min, dtype=torch.float32, device=device))
+
+             query_grid_num = 6
+             min_val = next_points.min(axis=0).values
+             max_val = next_points.max(axis=0).values
+             vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)
+             index = torch.floor(vol_queries_index).long()
+             index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
+             index = index.sort()
+             next_points = next_points[index.indices].unsqueeze(0).contiguous()
+             unique_values = torch.unique(index.values, return_counts=True)
+             grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)
+             input_grid = [[], []]
+             logits_grid_list = []
+             start_num = 0
+             sum_num = 0
+             for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
+                 if sum_num + count < num_chunks or sum_num == 0:
+                     sum_num += count
+                     input_grid[0].append(grid_index)
+                     input_grid[1].append(count)
+                 else:
+                     processor.topk = input_grid
+                     logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
+                     start_num = start_num + sum_num
+                     logits_grid_list.append(logits_grid)
+                     input_grid = [[grid_index], [count]]
+                     sum_num = count
+             if sum_num > 0:
+                 processor.topk = input_grid
+                 logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
+                 logits_grid_list.append(logits_grid)
+             logits_grid = torch.cat(logits_grid_list, dim=1)
+             grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)
+             next_logits[nidx] = grid_logits
+             grid_logits = next_logits.unsqueeze(0)
+
+         grid_logits[grid_logits == -10000.] = float('nan')
+
+         return grid_logits
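Both new decoders build the same coarse-to-fine resolution ladder before refining only the near-surface cells at each doubling. A tiny standalone sketch of that loop (not part of the commit), using the defaults from this diff and octree_resolution=256 as an illustrative input:

    def resolution_ladder(octree_resolution: int, min_resolution: int = 63):
        # Same halving loop as HierarchicalVolumeDecoding / FlashVDMVolumeDecoding above.
        resolutions = []
        if octree_resolution < min_resolution:
            resolutions.append(octree_resolution)
        while octree_resolution >= min_resolution:
            resolutions.append(octree_resolution)
            octree_resolution = octree_resolution // 2
        resolutions.reverse()
        return resolutions

    print(resolution_ladder(256))  # [64, 128, 256]: decode coarsely, then refine twice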
hy3dgen/shapegen/models/conditioner.py CHANGED
@@ -22,7 +22,6 @@
  # fine-tuning enabling code and other elements of the foregoing made publicly available
  # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
 
- import numpy as np
  import torch
  import torch.nn as nn
  from torchvision import transforms
@@ -34,26 +33,6 @@ from transformers import (
  )
 
 
- def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
-     """
-     embed_dim: output dimension for each position
-     pos: a list of positions to be encoded: size (M,)
-     out: (M, D)
-     """
-     assert embed_dim % 2 == 0
-     omega = np.arange(embed_dim // 2, dtype=np.float64)
-     omega /= embed_dim / 2.
-     omega = 1. / 10000 ** omega  # (D/2,)
-
-     pos = pos.reshape(-1)  # (M,)
-     out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
-
-     emb_sin = np.sin(out)  # (M, D/2)
-     emb_cos = np.cos(out)  # (M, D/2)
-
-     return np.concatenate([emb_sin, emb_cos], axis=1)
-
-
  class ImageEncoder(nn.Module):
      def __init__(
          self,
@@ -88,7 +67,7 @@ class ImageEncoder(nn.Module):
              ]
          )
 
-     def forward(self, image, mask=None, value_range=(-1, 1), **kwargs):
+     def forward(self, image, mask=None, value_range=(-1, 1)):
          if value_range is not None:
              low, high = value_range
              image = (image - low) / (high - low)
@@ -103,7 +82,7 @@
 
          return last_hidden_state
 
-     def unconditional_embedding(self, batch_size, **kwargs):
+     def unconditional_embedding(self, batch_size):
          device = next(self.model.parameters()).device
          dtype = next(self.model.parameters()).dtype
          zero = torch.zeros(
@@ -131,82 +110,11 @@ class DinoImageEncoder(ImageEncoder):
      std = [0.229, 0.224, 0.225]
 
 
- class DinoImageEncoderMV(DinoImageEncoder):
-     def __init__(
-         self,
-         version=None,
-         config=None,
-         use_cls_token=True,
-         image_size=224,
-         view_num=4,
-         **kwargs,
-     ):
-         super().__init__(version, config, use_cls_token, image_size, **kwargs)
-         self.view_num = view_num
-         self.num_patches = self.num_patches
-         pos = np.arange(self.view_num, dtype=np.float32)
-         view_embedding = torch.from_numpy(
-             get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)).float()
-
-         view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)
-         self.view_embed = view_embedding.unsqueeze(0)
-
-     def forward(self, image, mask=None, value_range=(-1, 1), view_idxs=None):
-         if value_range is not None:
-             low, high = value_range
-             image = (image - low) / (high - low)
-
-         image = image.to(self.model.device, dtype=self.model.dtype)
-
-         bs, num_views, c, h, w = image.shape
-         image = image.view(bs * num_views, c, h, w)
-
-         inputs = self.transform(image)
-         outputs = self.model(inputs)
-
-         last_hidden_state = outputs.last_hidden_state
-         last_hidden_state = last_hidden_state.view(
-             bs, num_views, last_hidden_state.shape[-2],
-             last_hidden_state.shape[-1]
-         )
-
-         view_embedding = self.view_embed.to(last_hidden_state.dtype).to(last_hidden_state.device)
-         if view_idxs is not None:
-             assert len(view_idxs) == bs
-             view_embeddings = []
-             for i in range(bs):
-                 view_idx = view_idxs[i]
-                 assert num_views == len(view_idx)
-                 view_embeddings.append(self.view_embed[:, view_idx, ...])
-             view_embedding = torch.cat(view_embeddings, 0).to(last_hidden_state.dtype).to(last_hidden_state.device)
-
-         if num_views != self.view_num:
-             view_embedding = view_embedding[:, :num_views, ...]
-         last_hidden_state = last_hidden_state + view_embedding
-         last_hidden_state = last_hidden_state.view(bs, num_views * last_hidden_state.shape[-2],
-                                                    last_hidden_state.shape[-1])
-         return last_hidden_state
-
-     def unconditional_embedding(self, batch_size, view_idxs=None, **kwargs):
-         device = next(self.model.parameters()).device
-         dtype = next(self.model.parameters()).dtype
-         zero = torch.zeros(
-             batch_size,
-             self.num_patches * len(view_idxs[0]),
-             self.model.config.hidden_size,
-             device=device,
-             dtype=dtype,
-         )
-         return zero
-
-
  def build_image_encoder(config):
      if config['type'] == 'CLIPImageEncoder':
          return CLIPImageEncoder(**config['kwargs'])
      elif config['type'] == 'DinoImageEncoder':
          return DinoImageEncoder(**config['kwargs'])
-     elif config['type'] == 'DinoImageEncoderMV':
-         return DinoImageEncoderMV(**config['kwargs'])
      else:
          raise ValueError(f'Unknown image encoder type: {config["type"]}')
 
@@ -221,17 +129,17 @@ class DualImageEncoder(nn.Module):
          self.main_image_encoder = build_image_encoder(main_image_encoder)
          self.additional_image_encoder = build_image_encoder(additional_image_encoder)
 
-     def forward(self, image, mask=None, **kwargs):
+     def forward(self, image, mask=None):
          outputs = {
-             'main': self.main_image_encoder(image, mask=mask, **kwargs),
-             'additional': self.additional_image_encoder(image, mask=mask, **kwargs),
+             'main': self.main_image_encoder(image, mask=mask),
+             'additional': self.additional_image_encoder(image, mask=mask),
          }
          return outputs
 
-     def unconditional_embedding(self, batch_size, **kwargs):
+     def unconditional_embedding(self, batch_size):
          outputs = {
-             'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
-             'additional': self.additional_image_encoder.unconditional_embedding(batch_size, **kwargs),
+             'main': self.main_image_encoder.unconditional_embedding(batch_size),
+             'additional': self.additional_image_encoder.unconditional_embedding(batch_size),
          }
          return outputs
 
@@ -244,14 +152,14 @@ class SingleImageEncoder(nn.Module):
          super().__init__()
          self.main_image_encoder = build_image_encoder(main_image_encoder)
 
-     def forward(self, image, mask=None, **kwargs):
+     def forward(self, image, mask=None):
          outputs = {
-             'main': self.main_image_encoder(image, mask=mask, **kwargs),
+             'main': self.main_image_encoder(image, mask=mask),
          }
          return outputs
 
-     def unconditional_embedding(self, batch_size, **kwargs):
+     def unconditional_embedding(self, batch_size):
          outputs = {
-             'main': self.main_image_encoder.unconditional_embedding(batch_size, **kwargs),
+             'main': self.main_image_encoder.unconditional_embedding(batch_size),
          }
          return outputs
hy3dgen/shapegen/models/denoisers/__init__.py CHANGED
@@ -1 +1,2 @@
  from .hunyuan3ddit import Hunyuan3DDiT
+ from .hunyuandit import HunYuanDiTPlain
hy3dgen/shapegen/models/denoisers/hunyuan3ddit.py CHANGED
@@ -70,15 +70,6 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
      return embedding
 
 
- class GELU(nn.Module):
-     def __init__(self, approximate='tanh'):
-         super().__init__()
-         self.approximate = approximate
-
-     def forward(self, x: Tensor) -> Tensor:
-         return nn.functional.gelu(x.contiguous(), approximate=self.approximate)
-
-
  class MLPEmbedder(nn.Module):
      def __init__(self, in_dim: int, hidden_dim: int):
          super().__init__()
@@ -181,7 +172,7 @@
          self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
          self.img_mlp = nn.Sequential(
              nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
-             GELU(approximate="tanh"),
+             nn.GELU(approximate="tanh"),
              nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
          )
 
@@ -192,7 +183,7 @@
          self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
          self.txt_mlp = nn.Sequential(
              nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
-             GELU(approximate="tanh"),
+             nn.GELU(approximate="tanh"),
              nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
          )
 
@@ -258,7 +249,7 @@
          self.hidden_size = hidden_size
          self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
 
-         self.mlp_act = GELU(approximate="tanh")
+         self.mlp_act = nn.GELU(approximate="tanh")
          self.modulation = Modulation(hidden_size, double=False)
 
      def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
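The custom GELU wrapper removed above computes the same tanh-approximated GELU as PyTorch's built-in module (apart from the extra .contiguous() call), so the swap is behavior-preserving. A quick check:

    import torch
    import torch.nn as nn

    x = torch.randn(4, 16)
    custom = nn.functional.gelu(x.contiguous(), approximate='tanh')  # what the deleted wrapper did
    builtin = nn.GELU(approximate='tanh')(x)                         # what the diff switches to
    print(torch.allclose(custom, builtin))  # True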
hy3dgen/shapegen/pipelines.py CHANGED
@@ -34,12 +34,11 @@ import trimesh
  import yaml
  from PIL import Image
  from diffusers.utils.torch_utils import randn_tensor
- from diffusers.utils.import_utils import is_accelerate_version, is_accelerate_available
  from tqdm import tqdm
 
  from .models.autoencoders import ShapeVAE
  from .models.autoencoders import SurfaceExtractors
- from .utils import logger, synchronize_timer, smart_load_model
+ from .utils import logger, synchronize_timer
 
 
  def retrieve_timesteps(
@@ -138,9 +137,6 @@ def instantiate_from_config(config, **kwargs):
 
 
  class Hunyuan3DDiTPipeline:
-     model_cpu_offload_seq = "conditioner->model->vae"
-     _exclude_from_cpu_offload = []
-
      @classmethod
      @synchronize_timer('Hunyuan3DDiTPipeline Model Loading')
      def from_single_file(
@@ -221,12 +217,34 @@
              dtype=dtype,
              device=device,
          )
-         config_path, ckpt_path = smart_load_model(
-             model_path,
-             subfolder=subfolder,
-             use_safetensors=use_safetensors,
-             variant=variant
-         )
+         original_model_path = model_path
+         # try local path
+         base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
+         model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
+         logger.info(f'Try to load model from local path: {model_path}')
+         if not os.path.exists(model_path):
+             logger.info('Model path not exists, try to download from huggingface')
+             try:
+                 import huggingface_hub
+                 # download from huggingface
+                 path = huggingface_hub.snapshot_download(repo_id=original_model_path)
+                 model_path = os.path.join(path, subfolder)
+             except ImportError:
+                 logger.warning(
+                     "You need to install HuggingFace Hub to load models from the hub."
+                 )
+                 raise RuntimeError(f"Model path {model_path} not found")
+             except Exception as e:
+                 raise e
+
+         if not os.path.exists(model_path):
+             raise FileNotFoundError(f"Model path {original_model_path} not found")
+
+         extension = 'ckpt' if not use_safetensors else 'safetensors'
+         variant = '' if variant is None else f'.{variant}'
+         ckpt_name = f'model{variant}.{extension}'
+         config_path = os.path.join(model_path, 'config.yaml')
+         ckpt_path = os.path.join(model_path, ckpt_name)
          return cls.from_single_file(
              ckpt_path,
             config_path,
@@ -271,18 +289,12 @@
          if enabled:
              model_path = self.kwargs['from_pretrained_kwargs']['model_path']
              turbo_vae_mapping = {
-                 'Hunyuan3D-2': ('tencent/Hunyuan3D-2', 'hunyuan3d-vae-v2-0-turbo'),
-                 'Hunyuan3D-2mv': ('tencent/Hunyuan3D-2', 'hunyuan3d-vae-v2-0-turbo'),
-                 'Hunyuan3D-2mini': ('tencent/Hunyuan3D-2mini', 'hunyuan3d-vae-v2-mini-turbo'),
+                 'Hunyuan3D-2': 'hunyuan3d-vae-v2-0-turbo',
+                 'Hunyuan3D-2s': 'hunyuan3d-vae-v2-s-turbo'
              }
              model_name = model_path.split('/')[-1]
              if replace_vae and model_name in turbo_vae_mapping:
-                 model_path, subfolder = turbo_vae_mapping[model_name]
-                 self.vae = ShapeVAE.from_pretrained(
-                     model_path, subfolder=subfolder,
-                     use_safetensors=self.kwargs['from_pretrained_kwargs']['use_safetensors'],
-                     device=self.device,
-                 )
+                 self.vae = ShapeVAE.from_pretrained(model_path, subfolder=turbo_vae_mapping[model_name])
              self.vae.enable_flashvdm_decoder(
                  enabled=enabled,
                  adaptive_kv_selection=adaptive_kv_selection,
@@ -292,146 +304,33 @@
          else:
              model_path = self.kwargs['from_pretrained_kwargs']['model_path']
              vae_mapping = {
-                 'Hunyuan3D-2': ('tencent/Hunyuan3D-2', 'hunyuan3d-vae-v2-0'),
-                 'Hunyuan3D-2mv': ('tencent/Hunyuan3D-2', 'hunyuan3d-vae-v2-0'),
-                 'Hunyuan3D-2mini': ('tencent/Hunyuan3D-2mini', 'hunyuan3d-vae-v2-mini'),
+                 'Hunyuan3D-2': 'hunyuan3d-vae-v2-0',
+                 'Hunyuan3D-2s': 'hunyuan3d-vae-v2-s'
              }
              model_name = model_path.split('/')[-1]
              if model_name in vae_mapping:
-                 model_path, subfolder = vae_mapping[model_name]
-                 self.vae = ShapeVAE.from_pretrained(model_path, subfolder=subfolder)
+                 self.vae = ShapeVAE.from_pretrained(model_path, subfolder=vae_mapping[model_name])
              self.vae.enable_flashvdm_decoder(enabled=False)
 
      def to(self, device=None, dtype=None):
-         if dtype is not None:
-             self.dtype = dtype
-             self.vae.to(dtype=dtype)
-             self.model.to(dtype=dtype)
-             self.conditioner.to(dtype=dtype)
          if device is not None:
              self.device = torch.device(device)
              self.vae.to(device)
              self.model.to(device)
              self.conditioner.to(device)
+         if dtype is not None:
+             self.dtype = dtype
+             self.vae.to(dtype=dtype)
+             self.model.to(dtype=dtype)
+             self.conditioner.to(dtype=dtype)
-
-     @property
-     def _execution_device(self):
-         r"""
-         Returns the device on which the pipeline's models will be executed. After calling
-         [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
-         Accelerate's module hooks.
-         """
-         for name, model in self.components.items():
-             if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:
-                 continue
-
-             if not hasattr(model, "_hf_hook"):
-                 return self.device
-             for module in model.modules():
-                 if (
-                     hasattr(module, "_hf_hook")
-                     and hasattr(module._hf_hook, "execution_device")
-                     and module._hf_hook.execution_device is not None
-                 ):
-                     return torch.device(module._hf_hook.execution_device)
-         return self.device
-
-     def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
-         r"""
-         Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
-         to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
-         method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
-         `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
-
-         Arguments:
-             gpu_id (`int`, *optional*):
-                 The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
-             device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
-                 The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
-                 default to "cuda".
-         """
-         if self.model_cpu_offload_seq is None:
-             raise ValueError(
-                 "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
-             )
-
-         if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-             from accelerate import cpu_offload_with_hook
-         else:
-             raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
-
-         torch_device = torch.device(device)
-         device_index = torch_device.index
-
-         if gpu_id is not None and device_index is not None:
-             raise ValueError(
-                 f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
-                 f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
-             )
-
-         # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
-         self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0)
-
-         device_type = torch_device.type
-         device = torch.device(f"{device_type}:{self._offload_gpu_id}")
-
-         if self.device.type != "cpu":
-             self.to("cpu")
-             device_mod = getattr(torch, self.device.type, None)
-             if hasattr(device_mod, "empty_cache") and device_mod.is_available():
-                 device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
-
-         all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
-
-         self._all_hooks = []
-         hook = None
-         for model_str in self.model_cpu_offload_seq.split("->"):
-             model = all_model_components.pop(model_str, None)
-             if not isinstance(model, torch.nn.Module):
-                 continue
-
-             _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
-             self._all_hooks.append(hook)
-
-         # CPU offload models that are not in the seq chain unless they are explicitly excluded
-         # these models will stay on CPU until maybe_free_model_hooks is called
-         # some models cannot be in the seq chain because they are iteratively called, such as controlnet
-         for name, model in all_model_components.items():
-             if not isinstance(model, torch.nn.Module):
-                 continue
-
-             if name in self._exclude_from_cpu_offload:
-                 model.to(device)
-             else:
-                 _, hook = cpu_offload_with_hook(model, device)
-                 self._all_hooks.append(hook)
-
-     def maybe_free_model_hooks(self):
-         r"""
-         Function that offloads all components, removes all model hooks that were added when using
-         `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function
-         is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
-         functions correctly when applying enable_model_cpu_offload.
-         """
-         if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
-             # `enable_model_cpu_offload` has not be called, so silently do nothing
-             return
-
-         for hook in self._all_hooks:
-             # offload model and remove hook from model
-             hook.offload()
-             hook.remove()
-
-         # make sure the model is in the same state as before calling it
-         self.enable_model_cpu_offload()
 
      @synchronize_timer('Encode cond')
-     def encode_cond(self, image, additional_cond_inputs, do_classifier_free_guidance, dual_guidance):
+     def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
          bsz = image.shape[0]
-         cond = self.conditioner(image=image, **additional_cond_inputs)
+         cond = self.conditioner(image=image, mask=mask)
 
          if do_classifier_free_guidance:
-             un_cond = self.conditioner.unconditional_embedding(bsz, **additional_cond_inputs)
+             un_cond = self.conditioner.unconditional_embedding(bsz)
 
              if dual_guidance:
                  un_cond_drop_main = copy.deepcopy(un_cond)
@@ -447,7 +346,7 @@
 
                  cond = cat_recursive(cond, un_cond_drop_main, un_cond)
              else:
-                 un_cond = self.conditioner.unconditional_embedding(bsz, **additional_cond_inputs)
+                 un_cond = self.conditioner.unconditional_embedding(bsz)
 
                  def cat_recursive(a, b):
                      if isinstance(a, torch.Tensor):
@@ -494,27 +393,25 @@
          latents = latents * getattr(self.scheduler, 'init_noise_sigma', 1.0)
          return latents
 
-     def prepare_image(self, image) -> dict:
+     def prepare_image(self, image):
          if isinstance(image, str) and not os.path.exists(image):
             raise FileNotFoundError(f"Couldn't find image at path {image}")
 
          if not isinstance(image, list):
              image = [image]
-
-         outputs = []
+         image_pts = []
+         mask_pts = []
          for img in image:
-             output = self.image_processor(img)
-             outputs.append(output)
-
-         cond_input = {k: [] for k in outputs[0].keys()}
-         for output in outputs:
-             for key, value in output.items():
-                 cond_input[key].append(value)
-         for key, value in cond_input.items():
-             if isinstance(value[0], torch.Tensor):
-                 cond_input[key] = torch.cat(value, dim=0)
-
-         return cond_input
+             image_pt, mask_pt = self.image_processor(img, return_mask=True)
+             image_pts.append(image_pt)
+             mask_pts.append(mask_pt)
+
+         image_pts = torch.cat(image_pts, dim=0).to(self.device, dtype=self.dtype)
+         if mask_pts[0] is not None:
+             mask_pts = torch.cat(mask_pts, dim=0).to(self.device, dtype=self.dtype)
+         else:
+             mask_pts = None
+         return image_pts, mask_pts
 
      def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
          """
@@ -589,6 +486,7 @@
 
          image, mask = self.prepare_image(image)
          cond = self.encode_cond(image=image,
+                                 mask=mask,
                                  do_classifier_free_guidance=do_classifier_free_guidance,
                                  dual_guidance=dual_guidance)
          batch_size = image.shape[0]
@@ -648,17 +546,7 @@
              box_v, mc_level, num_chunks, octree_resolution, mc_algo,
          )
 
-     def _export(
-         self,
-         latents,
-         output_type='trimesh',
-         box_v=1.01,
-         mc_level=0.0,
-         num_chunks=20000,
-         octree_resolution=256,
-         mc_algo='mc',
-         enable_pbar=True
-     ):
+     def _export(self, latents, output_type, box_v, mc_level, num_chunks, octree_resolution, mc_algo, enable_pbar=True):
          if not output_type == "latent":
              latents = 1. / self.vae.scale_factor * latents
              latents = self.vae(latents)
@@ -685,7 +573,7 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
      @torch.inference_mode()
      def __call__(
          self,
-         image: Union[str, List[str], Image.Image, dict, List[dict]] = None,
+         image: Union[str, List[str], Image.Image] = None,
          num_inference_steps: int = 50,
          timesteps: List[int] = None,
          sigmas: List[float] = None,
@@ -713,11 +601,10 @@
              self.model.guidance_embed is True
          )
 
-         cond_inputs = self.prepare_image(image)
-         image = cond_inputs.pop('image')
+         image, mask = self.prepare_image(image)
          cond = self.encode_cond(
              image=image,
-             additional_cond_inputs=cond_inputs,
+             mask=mask,
              do_classifier_free_guidance=do_classifier_free_guidance,
              dual_guidance=False,
          )
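A hedged end-to-end sketch of how this pipeline is typically driven after this commit. The repo id, default subfolder, image path, and output filename are assumptions for illustration; the loading behaviour (local HY3DGEN_MODELS cache first, then huggingface_hub.snapshot_download) follows the inlined from_pretrained logic above.

    import os
    # Optional: point the loader at a local cache instead of downloading.
    os.environ.setdefault('HY3DGEN_MODELS', os.path.expanduser('~/.cache/hy3dgen'))

    from hy3dgen.shapegen import (
        Hunyuan3DDiTFlowMatchingPipeline, FaceReducer, FloaterRemover, DegenerateFaceRemover,
    )

    # 'tencent/Hunyuan3D-2' is an assumed repo id.
    pipeline = Hunyuan3DDiTFlowMatchingPipeline.from_pretrained('tencent/Hunyuan3D-2')

    mesh = pipeline(image='assets/demo.png')[0]   # placeholder input path
    mesh = FloaterRemover()(mesh)
    mesh = DegenerateFaceRemover()(mesh)
    mesh = FaceReducer()(mesh)
    mesh.export('demo.glb')                       # placeholder output path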
hy3dgen/shapegen/postprocessors.py CHANGED
@@ -22,16 +22,13 @@
  # fine-tuning enabling code and other elements of the foregoing made publicly available
  # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
 
- import os
  import tempfile
  from typing import Union
 
- import numpy as np
  import pymeshlab
- import torch
  import trimesh
 
- from .models.autoencoders import Latent2MeshOutput
+ from .models.vae import Latent2MeshOutput
  from .utils import synchronize_timer
 
 
@@ -165,62 +162,3 @@ class DegenerateFaceRemover:
 
          mesh = export_mesh(mesh, ms)
          return mesh
-
-
- def import_pymeshlab_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str]) -> pymeshlab.MeshSet:
-     if isinstance(mesh, str):
-         mesh = load_mesh(mesh)
-     elif isinstance(mesh, Latent2MeshOutput):
-         mesh = pymeshlab.MeshSet()
-         mesh_pymeshlab = pymeshlab.Mesh(vertex_matrix=mesh.mesh_v, face_matrix=mesh.mesh_f)
-         mesh.add_mesh(mesh_pymeshlab, "converted_mesh")
-
-     if isinstance(mesh, (trimesh.Trimesh, trimesh.scene.Scene)):
-         mesh = trimesh2pymeshlab(mesh)
-
-     return mesh
-
-
- def mesh_normalize(mesh):
-     """
-     Normalize mesh vertices to sphere
-     """
-     scale_factor = 1.2
-     vtx_pos = np.asarray(mesh.vertices)
-     max_bb = (vtx_pos - 0).max(0)[0]
-     min_bb = (vtx_pos - 0).min(0)[0]
-
-     center = (max_bb + min_bb) / 2
-
-     scale = torch.norm(torch.tensor(vtx_pos - center, dtype=torch.float32), dim=1).max() * 2.0
-
-     vtx_pos = (vtx_pos - center) * (scale_factor / float(scale))
-     mesh.vertices = vtx_pos
-
-     return mesh
-
-
- class MeshSimplifier:
-     def __init__(self, executable: str = None):
-         if executable is None:
-             CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
-             executable = os.path.join(CURRENT_DIR, "mesh_simplifier.bin")
-         self.executable = executable
-
-     @synchronize_timer('MeshSimplifier')
-     def __call__(
-         self,
-         mesh: Union[trimesh.Trimesh],
-     ) -> Union[trimesh.Trimesh]:
-         with tempfile.NamedTemporaryFile(suffix='.obj', delete=True) as temp_input:
-             with tempfile.NamedTemporaryFile(suffix='.obj', delete=True) as temp_output:
-                 mesh.export(temp_input.name)
-                 os.system(f'{self.executable} {temp_input.name} {temp_output.name}')
-             ms = trimesh.load(temp_output.name, process=False)
-             if isinstance(ms, trimesh.Scene):
-                 combined_mesh = trimesh.Trimesh()
-                 for geom in ms.geometry.values():
-                     combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
-                 ms = combined_mesh
-             ms = mesh_normalize(ms)
-             return ms
hy3dgen/shapegen/preprocessors.py CHANGED
@@ -96,7 +96,7 @@ class ImageProcessorV2:
          mask = mask.clip(0, 255).astype(np.uint8)
          return result, mask
 
-     def __call__(self, image, border_ratio=0.15, to_tensor=True, **kwargs):
+     def __call__(self, image, border_ratio=0.15, to_tensor=True, return_mask=False, **kwargs):
          if self.border_ratio is not None:
              border_ratio = self.border_ratio
          if isinstance(image, str):
@@ -115,74 +115,13 @@
          if to_tensor:
              image = array_to_tensor(image)
              mask = array_to_tensor(mask)
-
-         outputs = {
-             'image': image,
-             'mask': mask
-         }
-         return outputs
-
-
- class MVImageProcessorV2(ImageProcessorV2):
-     """
-     view order: front, front clockwise 90, back, front clockwise 270
-     """
-     return_view_idx = True
-
-     def __init__(self, size=512, border_ratio=None):
-         super().__init__(size, border_ratio)
-         self.view2idx = {
-             'front': 0,
-             'left': 1,
-             'back': 2,
-             'right': 3
-         }
-
-     def __call__(self, image_dict, border_ratio=0.15, to_tensor=True, **kwargs):
-         if self.border_ratio is not None:
-             border_ratio = self.border_ratio
-
-         images = []
-         masks = []
-         view_idxs = []
-         for idx, (view_tag, image) in enumerate(image_dict.items()):
-             view_idxs.append(self.view2idx[view_tag])
-             if isinstance(image, str):
-                 image = cv2.imread(image, cv2.IMREAD_UNCHANGED)
-                 image, mask = self.recenter(image, border_ratio=border_ratio)
-                 image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
-             elif isinstance(image, Image.Image):
-                 image = image.convert("RGBA")
-                 image = np.asarray(image)
-                 image, mask = self.recenter(image, border_ratio=border_ratio)
-
-             image = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
-             mask = cv2.resize(mask, (self.size, self.size), interpolation=cv2.INTER_NEAREST)
-             mask = mask[..., np.newaxis]
-
-             if to_tensor:
-                 image = array_to_tensor(image)
-                 mask = array_to_tensor(mask)
-             images.append(image)
-             masks.append(mask)
-
-         zipped_lists = zip(view_idxs, images, masks)
-         sorted_zipped_lists = sorted(zipped_lists)
-         view_idxs, images, masks = zip(*sorted_zipped_lists)
-
-         image = torch.cat(images, 0).unsqueeze(0)
-         mask = torch.cat(masks, 0).unsqueeze(0)
-         outputs = {
-             'image': image,
-             'mask': mask,
-             'view_idxs': view_idxs
-         }
-         return outputs
+         if return_mask:
+             return image, mask
+         return image
 
 
  IMAGE_PROCESSORS = {
      "v2": ImageProcessorV2,
-     'mv_v2': MVImageProcessorV2,
  }
 
  DEFAULT_IMAGEPROCESSOR = 'v2'
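A small hedged sketch of the updated preprocessor call path; return_mask=True is the new flag that prepare_image() in pipelines.py relies on after this commit, and the input path is a placeholder:

    from hy3dgen.shapegen.preprocessors import ImageProcessorV2

    processor = ImageProcessorV2(size=512)
    # With return_mask=True the call returns (image, mask) tensors instead of just the image.
    image, mask = processor('assets/demo.png', return_mask=True)
    print(image.shape, mask.shape)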
hy3dgen/shapegen/utils.py CHANGED
@@ -70,40 +70,3 @@ class synchronize_timer:
              return result
 
          return wrapper
-
-
- def smart_load_model(
-     model_path,
-     subfolder,
-     use_safetensors,
-     variant,
- ):
-     original_model_path = model_path
-     # try local path
-     base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
-     model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
-     logger.info(f'Try to load model from local path: {model_path}')
-     if not os.path.exists(model_path):
-         logger.info('Model path not exists, try to download from huggingface')
-         try:
-             import huggingface_hub
-             # download from huggingface
-             path = huggingface_hub.snapshot_download(repo_id=original_model_path)
-             model_path = os.path.join(path, subfolder)
-         except ImportError:
-             logger.warning(
-                 "You need to install HuggingFace Hub to load models from the hub."
-             )
-             raise RuntimeError(f"Model path {model_path} not found")
-         except Exception as e:
-             raise e
-
-     if not os.path.exists(model_path):
-         raise FileNotFoundError(f"Model path {original_model_path} not found")
-
-     extension = 'ckpt' if not use_safetensors else 'safetensors'
-     variant = '' if variant is None else f'.{variant}'
-     ckpt_name = f'model{variant}.{extension}'
-     config_path = os.path.join(model_path, 'config.yaml')
-     ckpt_path = os.path.join(model_path, ckpt_name)
-     return config_path, ckpt_path