Spaces: Running on Zero

Upload 16 files

- hy3dgen/shapegen/__init__.py +1 -1
- hy3dgen/shapegen/models/__init__.py +1 -1
- hy3dgen/shapegen/models/autoencoders/__init__.py +3 -3
- hy3dgen/shapegen/models/autoencoders/attention_processors.py +92 -0
- hy3dgen/shapegen/models/autoencoders/model.py +20 -1
- hy3dgen/shapegen/models/autoencoders/volume_decoders.py +353 -0
- hy3dgen/shapegen/models/conditioner.py +12 -104
- hy3dgen/shapegen/models/denoisers/__init__.py +1 -0
- hy3dgen/shapegen/models/denoisers/hunyuan3ddit.py +3 -12
- hy3dgen/shapegen/pipelines.py +61 -174
- hy3dgen/shapegen/postprocessors.py +1 -63
- hy3dgen/shapegen/preprocessors.py +4 -65
- hy3dgen/shapegen/utils.py +0 -37
hy3dgen/shapegen/__init__.py
CHANGED
@@ -23,5 +23,5 @@
 # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
 
 from .pipelines import Hunyuan3DDiTPipeline, Hunyuan3DDiTFlowMatchingPipeline
-from .postprocessors import FaceReducer, FloaterRemover, DegenerateFaceRemover
+from .postprocessors import FaceReducer, FloaterRemover, DegenerateFaceRemover
 from .preprocessors import ImageProcessorV2, IMAGE_PROCESSORS, DEFAULT_IMAGEPROCESSOR
hy3dgen/shapegen/models/__init__.py
CHANGED
@@ -25,4 +25,4 @@
 
 from .autoencoders import ShapeVAE
 from .conditioner import DualImageEncoder, SingleImageEncoder, DinoImageEncoder, CLIPImageEncoder
-from .denoisers import Hunyuan3DDiT
+from .denoisers import HunYuanDiTPlain, Hunyuan3DDiT
hy3dgen/shapegen/models/autoencoders/__init__.py
CHANGED
@@ -1,5 +1,5 @@
 from .attention_blocks import CrossAttentionDecoder
-from .attention_processors import CrossAttentionProcessor
+from .attention_processors import FlashVDMCrossAttentionProcessor, CrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
 from .model import ShapeVAE, VectsetVAE
-from .surface_extractors import SurfaceExtractors, MCSurfaceExtractor, DMCSurfaceExtractor
-from .volume_decoders import
+from .surface_extractors import SurfaceExtractors, MCSurfaceExtractor, DMCSurfaceExtractor
+from .volume_decoders import HierarchicalVolumeDecoding, FlashVDMVolumeDecoding, VanillaVolumeDecoder
hy3dgen/shapegen/models/autoencoders/attention_processors.py
CHANGED
@@ -17,3 +17,95 @@ class CrossAttentionProcessor:
         out = scaled_dot_product_attention(q, k, v)
         return out
 
+
+class FlashVDMCrossAttentionProcessor:
+    def __init__(self, topk=None):
+        self.topk = topk
+
+    def __call__(self, attn, q, k, v):
+        if k.shape[-2] == 3072:
+            topk = 1024
+        elif k.shape[-2] == 512:
+            topk = 256
+        else:
+            topk = k.shape[-2] // 3
+
+        if self.topk is True:
+            q1 = q[:, :, ::100, :]
+            sim = q1 @ k.transpose(-1, -2)
+            sim = torch.mean(sim, -2)
+            topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
+            topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
+            v0 = torch.gather(v, dim=-2, index=topk_ind)
+            k0 = torch.gather(k, dim=-2, index=topk_ind)
+            out = scaled_dot_product_attention(q, k0, v0)
+        elif self.topk is False:
+            out = scaled_dot_product_attention(q, k, v)
+        else:
+            idx, counts = self.topk
+            start = 0
+            outs = []
+            for grid_coord, count in zip(idx, counts):
+                end = start + count
+                q_chunk = q[:, :, start:end, :]
+                q1 = q_chunk[:, :, ::50, :]
+                sim = q1 @ k.transpose(-1, -2)
+                sim = torch.mean(sim, -2)
+                topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
+                topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
+                v0 = torch.gather(v, dim=-2, index=topk_ind)
+                k0 = torch.gather(k, dim=-2, index=topk_ind)
+                out = scaled_dot_product_attention(q_chunk, k0, v0)
+                outs.append(out)
+                start += count
+            out = torch.cat(outs, dim=-2)
+        self.topk = False
+        return out
+
+
+class FlashVDMTopMCrossAttentionProcessor:
+    def __init__(self, topk=None):
+        self.topk = topk
+
+    def __call__(self, attn, q, k, v):
+        if k.shape[-2] == 3072:
+            topk = 1024
+        elif k.shape[-2] == 512:
+            topk = 256
+        else:
+            topk = k.shape[-2] // 3
+
+        if self.topk is True:
+            q1 = q[:, :, ::100, :]
+            sim = q1 @ k.transpose(-1, -2)
+            sim = torch.mean(sim, -2)
+            topk_ind = torch.topk(sim, dim=-1, k=topk).indices.squeeze(-2).unsqueeze(-1)
+            topk_ind = topk_ind.expand(-1, -1, -1, v.shape[-1])
+            v0 = torch.gather(v, dim=-2, index=topk_ind)
+            k0 = torch.gather(k, dim=-2, index=topk_ind)
+            out = scaled_dot_product_attention(q, k0, v0)
+        elif self.topk is False:
+            out = scaled_dot_product_attention(q, k, v)
+        else:
+            idx, counts = self.topk
+            start = 0
+            outs = []
+            for grid_coord, count in zip(idx, counts):
+                end = start + count
+                q_chunk = q[:, :, start:end, :]
+                q1 = q_chunk[:, :, ::30, :]
+                sim = q1 @ k.transpose(-1, -2)
+                # sim = sim.to(torch.float32)
+                sim = sim.softmax(-1)
+                sim = torch.mean(sim, 1)
+                activated_token = torch.where(sim > 1e-6)[2]
+                index = torch.unique(activated_token, return_counts=True)[0].unsqueeze(0).unsqueeze(0).unsqueeze(-1)
+                index = index.expand(-1, v.shape[1], -1, v.shape[-1])
+                v0 = torch.gather(v, dim=-2, index=index)
+                k0 = torch.gather(k, dim=-2, index=index)
+                out = scaled_dot_product_attention(q_chunk, k0, v0)  # bhnc
+                outs.append(out)
+                start += count
+            out = torch.cat(outs, dim=-2)
+        self.topk = False
+        return out
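Note: both processors above shrink the latent key/value set before calling scaled_dot_product_attention. The following is a standalone sketch of that top-k selection step (the `self.topk is True` branch) on random tensors; the batch, head, query and latent sizes are invented for illustration, and only PyTorch 2.x is required.

import torch
from torch.nn.functional import scaled_dot_product_attention

# q is (batch, heads, num_queries, dim); k/v are (batch, heads, num_latents, dim).
b, h, nq, nk, d = 1, 8, 8000, 3072, 64
q = torch.randn(b, h, nq, d)
k = torch.randn(b, h, nk, d)
v = torch.randn(b, h, nk, d)

topk = 1024  # the processor uses 1024 when the KV length is 3072

# Score every latent with a strided subset of the queries, then keep only the
# top-k keys/values for the full attention call.
q1 = q[:, :, ::100, :]
sim = (q1 @ k.transpose(-1, -2)).mean(-2)                 # (b, h, nk)
topk_ind = torch.topk(sim, dim=-1, k=topk).indices.unsqueeze(-1)
topk_ind = topk_ind.expand(-1, -1, -1, d)                 # (b, h, topk, d)
k0 = torch.gather(k, dim=-2, index=topk_ind)
v0 = torch.gather(v, dim=-2, index=topk_ind)

out = scaled_dot_product_attention(q, k0, v0)
print(out.shape)  # torch.Size([1, 8, 8000, 64])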
hy3dgen/shapegen/models/autoencoders/model.py
CHANGED
@@ -6,7 +6,7 @@ import yaml
 
 from .attention_blocks import FourierEmbedder, Transformer, CrossAttentionDecoder
 from .surface_extractors import MCSurfaceExtractor, SurfaceExtractors
-from .volume_decoders import VanillaVolumeDecoder
+from .volume_decoders import VanillaVolumeDecoder, FlashVDMVolumeDecoding, HierarchicalVolumeDecoding
 from ...utils import logger, synchronize_timer
 
 
@@ -117,6 +117,25 @@ class VectsetVAE(nn.Module):
         outputs = self.surface_extractor(grid_logits, **kwargs)
         return outputs
 
+    def enable_flashvdm_decoder(
+        self,
+        enabled: bool = True,
+        adaptive_kv_selection=True,
+        topk_mode='mean',
+        mc_algo='dmc',
+    ):
+        if enabled:
+            if adaptive_kv_selection:
+                self.volume_decoder = FlashVDMVolumeDecoding(topk_mode)
+            else:
+                self.volume_decoder = HierarchicalVolumeDecoding()
+            if mc_algo not in SurfaceExtractors.keys():
+                raise ValueError(f'Unsupported mc_algo {mc_algo}, available: {list(SurfaceExtractors.keys())}')
+            self.surface_extractor = SurfaceExtractors[mc_algo]()
+        else:
+            self.volume_decoder = VanillaVolumeDecoder()
+            self.surface_extractor = MCSurfaceExtractor()
+
 
 class ShapeVAE(VectsetVAE):
     def __init__(
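Note: a minimal usage sketch for the enable_flashvdm_decoder switch added above. It assumes the hy3dgen package from this commit is installed and that the ShapeVAE weights (the repo id and subfolder below mirror the vae_mapping used in pipelines.py) can be resolved locally or from the Hugging Face Hub; adjust both to your setup.

from hy3dgen.shapegen.models.autoencoders import ShapeVAE

vae = ShapeVAE.from_pretrained('tencent/Hunyuan3D-2', subfolder='hunyuan3d-vae-v2-0')

# Adaptive KV selection -> FlashVDMVolumeDecoding plus the requested surface extractor.
vae.enable_flashvdm_decoder(enabled=True, adaptive_kv_selection=True,
                            topk_mode='mean', mc_algo='dmc')

# enabled=False falls back to VanillaVolumeDecoder + MCSurfaceExtractor.
vae.enable_flashvdm_decoder(enabled=False)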
hy3dgen/shapegen/models/autoencoders/volume_decoders.py
CHANGED
@@ -8,9 +8,111 @@ from einops import repeat
 from tqdm import tqdm
 
 from .attention_blocks import CrossAttentionDecoder
+from .attention_processors import FlashVDMCrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
 from ...utils import logger
 
 
+def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
+    """
+    PyTorch implementation with the dimension issues fixed
+    Args:
+        input_tensor: shape [D, D, D], torch.float16
+        alpha: scalar offset value
+    Returns:
+        mask: shape [D, D, D], torch.int32 surface mask
+    """
+    device = input_tensor.device
+    D = input_tensor.shape[0]
+    signed_val = 0.0
+
+    # Add the offset and handle invalid values
+    val = input_tensor + alpha
+    valid_mask = val > -9000  # assume -9000 marks invalid values
+
+    # Improved neighbour fetch that keeps the dimensions consistent
+    def get_neighbor(t, shift, axis):
+        """Shift along the given axis while keeping the dimensions consistent"""
+        if shift == 0:
+            return t.clone()
+
+        # Determine the padding axis (the [D, D, D] input corresponds to the z, y, x axes)
+        pad_dims = [0, 0, 0, 0, 0, 0]  # format: [x_before, x_after, y_before, y_after, z_before, z_after]
+
+        # Set the padding according to the axis
+        if axis == 0:  # x axis (last dimension)
+            pad_idx = 0 if shift > 0 else 1
+            pad_dims[pad_idx] = abs(shift)
+        elif axis == 1:  # y axis (middle dimension)
+            pad_idx = 2 if shift > 0 else 3
+            pad_dims[pad_idx] = abs(shift)
+        elif axis == 2:  # z axis (first dimension)
+            pad_idx = 4 if shift > 0 else 5
+            pad_dims[pad_idx] = abs(shift)
+
+        # Apply the padding (add batch and channel dims to fit F.pad)
+        padded = F.pad(t.unsqueeze(0).unsqueeze(0), pad_dims[::-1], mode='replicate')  # reversed order for F.pad
+
+        # Build the dynamic slice indices
+        slice_dims = [slice(None)] * 3  # initialise to full slices
+        if axis == 0:  # x axis (dim=2)
+            if shift > 0:
+                slice_dims[0] = slice(shift, None)
+            else:
+                slice_dims[0] = slice(None, shift)
+        elif axis == 1:  # y axis (dim=1)
+            if shift > 0:
+                slice_dims[1] = slice(shift, None)
+            else:
+                slice_dims[1] = slice(None, shift)
+        elif axis == 2:  # z axis (dim=0)
+            if shift > 0:
+                slice_dims[2] = slice(shift, None)
+            else:
+                slice_dims[2] = slice(None, shift)
+
+        # Apply the slice and restore the dimensions
+        padded = padded.squeeze(0).squeeze(0)
+        sliced = padded[slice_dims]
+        return sliced
+
+    # Fetch the neighbours in every direction (dimensions stay consistent)
+    left = get_neighbor(val, 1, axis=0)  # x direction
+    right = get_neighbor(val, -1, axis=0)
+    back = get_neighbor(val, 1, axis=1)  # y direction
+    front = get_neighbor(val, -1, axis=1)
+    down = get_neighbor(val, 1, axis=2)  # z direction
+    up = get_neighbor(val, -1, axis=2)
+
+    # Handle invalid boundary values (use where to keep the dimensions consistent)
+    def safe_where(neighbor):
+        return torch.where(neighbor > -9000, neighbor, val)
+
+    left = safe_where(left)
+    right = safe_where(right)
+    back = safe_where(back)
+    front = safe_where(front)
+    down = safe_where(down)
+    up = safe_where(up)
+
+    # Compute sign consistency (cast to float32 for precision)
+    sign = torch.sign(val.to(torch.float32))
+    neighbors_sign = torch.stack([
+        torch.sign(left.to(torch.float32)),
+        torch.sign(right.to(torch.float32)),
+        torch.sign(back.to(torch.float32)),
+        torch.sign(front.to(torch.float32)),
+        torch.sign(down.to(torch.float32)),
+        torch.sign(up.to(torch.float32))
+    ], dim=0)
+
+    # Check whether all the signs agree
+    same_sign = torch.all(neighbors_sign == sign, dim=0)
+
+    # Produce the final mask
+    mask = (~same_sign).to(torch.int32)
+    return mask * valid_mask.to(torch.int32)
+
+
 def generate_dense_grid_points(
     bbox_min: np.ndarray,
     bbox_max: np.ndarray,
@@ -74,3 +176,254 @@ class VanillaVolumeDecoder:
         return grid_logits
 
 
+class HierarchicalVolumeDecoding:
+    @torch.no_grad()
+    def __call__(
+        self,
+        latents: torch.FloatTensor,
+        geo_decoder: Callable,
+        bounds: Union[Tuple[float], List[float], float] = 1.01,
+        num_chunks: int = 10000,
+        mc_level: float = 0.0,
+        octree_resolution: int = None,
+        min_resolution: int = 63,
+        enable_pbar: bool = True,
+        **kwargs,
+    ):
+        device = latents.device
+        dtype = latents.dtype
+
+        resolutions = []
+        if octree_resolution < min_resolution:
+            resolutions.append(octree_resolution)
+        while octree_resolution >= min_resolution:
+            resolutions.append(octree_resolution)
+            octree_resolution = octree_resolution // 2
+        resolutions.reverse()
+
+        # 1. generate query points
+        if isinstance(bounds, float):
+            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
+        bbox_min = np.array(bounds[0:3])
+        bbox_max = np.array(bounds[3:6])
+        bbox_size = bbox_max - bbox_min
+
+        xyz_samples, grid_size, length = generate_dense_grid_points(
+            bbox_min=bbox_min,
+            bbox_max=bbox_max,
+            octree_resolution=resolutions[0],
+            indexing="ij"
+        )
+
+        dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
+        dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
+
+        grid_size = np.array(grid_size)
+        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)
+
+        # 2. latents to 3d volume
+        batch_logits = []
+        batch_size = latents.shape[0]
+        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
+                          desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]"):
+            queries = xyz_samples[start: start + num_chunks, :]
+            batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
+            logits = geo_decoder(queries=batch_queries, latents=latents)
+            batch_logits.append(logits)
+
+        grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, grid_size[0], grid_size[1], grid_size[2]))
+
+        for octree_depth_now in resolutions[1:]:
+            grid_size = np.array([octree_depth_now + 1] * 3)
+            resolution = bbox_size / octree_depth_now
+            next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
+            next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
+            curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
+            curr_points += grid_logits.squeeze(0).abs() < 0.95
+
+            if octree_depth_now == resolutions[-1]:
+                expand_num = 0
+            else:
+                expand_num = 1
+            for i in range(expand_num):
+                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
+            (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
+            next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
+            for i in range(2 - expand_num):
+                next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
+            nidx = torch.where(next_index > 0)

+            next_points = torch.stack(nidx, dim=1)
+            next_points = (next_points * torch.tensor(resolution, dtype=next_points.dtype, device=device) +
+                           torch.tensor(bbox_min, dtype=next_points.dtype, device=device))
+            batch_logits = []
+            for start in tqdm(range(0, next_points.shape[0], num_chunks),
+                              desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]"):
+                queries = next_points[start: start + num_chunks, :]
+                batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
+                logits = geo_decoder(queries=batch_queries.to(latents.dtype), latents=latents)
+                batch_logits.append(logits)
+            grid_logits = torch.cat(batch_logits, dim=1)
+            next_logits[nidx] = grid_logits[0, ..., 0]
+            grid_logits = next_logits.unsqueeze(0)
+            grid_logits[grid_logits == -10000.] = float('nan')
+
+        return grid_logits
+
+
+class FlashVDMVolumeDecoding:
+    def __init__(self, topk_mode='mean'):
+        if topk_mode not in ['mean', 'merge']:
+            raise ValueError(f'Unsupported topk_mode {topk_mode}, available: {["mean", "merge"]}')
+
+        if topk_mode == 'mean':
+            self.processor = FlashVDMCrossAttentionProcessor()
+        else:
+            self.processor = FlashVDMTopMCrossAttentionProcessor()
+
+    @torch.no_grad()
+    def __call__(
+        self,
+        latents: torch.FloatTensor,
+        geo_decoder: CrossAttentionDecoder,
+        bounds: Union[Tuple[float], List[float], float] = 1.01,
+        num_chunks: int = 10000,
+        mc_level: float = 0.0,
+        octree_resolution: int = None,
+        min_resolution: int = 63,
+        mini_grid_num: int = 4,
+        enable_pbar: bool = True,
+        **kwargs,
+    ):
+        processor = self.processor
+        geo_decoder.set_cross_attention_processor(processor)
+
+        device = latents.device
+        dtype = latents.dtype
+
+        resolutions = []
+        if octree_resolution < min_resolution:
+            resolutions.append(octree_resolution)
+        while octree_resolution >= min_resolution:
+            resolutions.append(octree_resolution)
+            octree_resolution = octree_resolution // 2
+        resolutions.reverse()
+        resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
+        for i, resolution in enumerate(resolutions[1:]):
+            resolutions[i + 1] = resolutions[0] * 2 ** (i + 1)
+
+        logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}")
+
+        # 1. generate query points
+        if isinstance(bounds, float):
+            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
+        bbox_min = np.array(bounds[0:3])
+        bbox_max = np.array(bounds[3:6])
+        bbox_size = bbox_max - bbox_min
+
+        xyz_samples, grid_size, length = generate_dense_grid_points(
+            bbox_min=bbox_min,
+            bbox_max=bbox_max,
+            octree_resolution=resolutions[0],
+            indexing="ij"
+        )
+
+        dilate = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
+        dilate.weight = torch.nn.Parameter(torch.ones(dilate.weight.shape, dtype=dtype, device=device))
+
+        grid_size = np.array(grid_size)
+
+        # 2. latents to 3d volume
+        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
+        batch_size = latents.shape[0]
+        mini_grid_size = xyz_samples.shape[0] // mini_grid_num
+        xyz_samples = xyz_samples.view(
+            mini_grid_num, mini_grid_size,
+            mini_grid_num, mini_grid_size,
+            mini_grid_num, mini_grid_size, 3
+        ).permute(
+            0, 2, 4, 1, 3, 5, 6
+        ).reshape(
+            -1, mini_grid_size * mini_grid_size * mini_grid_size, 3
+        )
+        batch_logits = []
+        num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
+        for start in tqdm(range(0, xyz_samples.shape[0], num_batchs),
+                          desc=f"FlashVDM Volume Decoding", disable=not enable_pbar):
+            queries = xyz_samples[start: start + num_batchs, :]
+            batch = queries.shape[0]
+            batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
+            processor.topk = True
+            logits = geo_decoder(queries=queries, latents=batch_latents)
+            batch_logits.append(logits)
+        grid_logits = torch.cat(batch_logits, dim=0).reshape(
+            mini_grid_num, mini_grid_num, mini_grid_num,
+            mini_grid_size, mini_grid_size,
+            mini_grid_size
+        ).permute(0, 3, 1, 4, 2, 5).contiguous().view(
+            (batch_size, grid_size[0], grid_size[1], grid_size[2])
+        )
+
+        for octree_depth_now in resolutions[1:]:
+            grid_size = np.array([octree_depth_now + 1] * 3)
+            resolution = bbox_size / octree_depth_now
+            next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
+            next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)
+            curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
+            curr_points += grid_logits.squeeze(0).abs() < 0.95
+
+            if octree_depth_now == resolutions[-1]:
+                expand_num = 0
+            else:
+                expand_num = 1
+            for i in range(expand_num):
+                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)
+            (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
+
+            next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
+            for i in range(2 - expand_num):
+                next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
+            nidx = torch.where(next_index > 0)
+
+            next_points = torch.stack(nidx, dim=1)
+            next_points = (next_points * torch.tensor(resolution, dtype=torch.float32, device=device) +
+                           torch.tensor(bbox_min, dtype=torch.float32, device=device))
+
+            query_grid_num = 6
+            min_val = next_points.min(axis=0).values
+            max_val = next_points.max(axis=0).values
+            vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)
+            index = torch.floor(vol_queries_index).long()
+            index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
+            index = index.sort()
+            next_points = next_points[index.indices].unsqueeze(0).contiguous()
+            unique_values = torch.unique(index.values, return_counts=True)
+            grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)
+            input_grid = [[], []]
+            logits_grid_list = []
+            start_num = 0
+            sum_num = 0
+            for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
+                if sum_num + count < num_chunks or sum_num == 0:
+                    sum_num += count
+                    input_grid[0].append(grid_index)
+                    input_grid[1].append(count)
+                else:
+                    processor.topk = input_grid
+                    logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
+                    start_num = start_num + sum_num
+                    logits_grid_list.append(logits_grid)
+                    input_grid = [[grid_index], [count]]
+                    sum_num = count
+            if sum_num > 0:
+                processor.topk = input_grid
+                logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num], latents=latents)
+                logits_grid_list.append(logits_grid)
+            logits_grid = torch.cat(logits_grid_list, dim=1)
+            grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)
+            next_logits[nidx] = grid_logits
+            grid_logits = next_logits.unsqueeze(0)
+
+        grid_logits[grid_logits == -10000.] = float('nan')
+
+        return grid_logits
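Note: both new decoders use extract_near_surface_volume_fn to find voxels whose sign flips against a 6-neighbour, i.e. cells crossed by the mc_level iso-surface, and refine only those cells at the next resolution. A small sketch of that helper on a toy sphere SDF; it assumes the hy3dgen package from this commit is importable, and the grid size is illustrative.

import torch

from hy3dgen.shapegen.models.autoencoders.volume_decoders import extract_near_surface_volume_fn

# Toy signed-distance-like grid: a sphere of radius 0.5 sampled on a 65^3 lattice.
D = 65
coords = torch.linspace(-1, 1, D)
z, y, x = torch.meshgrid(coords, coords, coords, indexing='ij')
sdf = (x ** 2 + y ** 2 + z ** 2).sqrt() - 0.5

# alpha shifts the level being extracted; 0.0 keeps the zero level set.
mask = extract_near_surface_volume_fn(sdf, alpha=0.0)
print(mask.shape, int(mask.sum()))  # [65, 65, 65] and a thin shell of flagged voxels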
hy3dgen/shapegen/models/conditioner.py
CHANGED
@@ -22,7 +22,6 @@
 # fine-tuning enabling code and other elements of the foregoing made publicly available
 # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
 
-import numpy as np
 import torch
 import torch.nn as nn
 from torchvision import transforms
@@ -34,26 +33,6 @@ from transformers import (
 )
 
 
-def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
-    """
-    embed_dim: output dimension for each position
-    pos: a list of positions to be encoded: size (M,)
-    out: (M, D)
-    """
-    assert embed_dim % 2 == 0
-    omega = np.arange(embed_dim // 2, dtype=np.float64)
-    omega /= embed_dim / 2.
-    omega = 1. / 10000 ** omega  # (D/2,)
-
-    pos = pos.reshape(-1)  # (M,)
-    out = np.einsum('m,d->md', pos, omega)  # (M, D/2), outer product
-
-    emb_sin = np.sin(out)  # (M, D/2)
-    emb_cos = np.cos(out)  # (M, D/2)
-
-    return np.concatenate([emb_sin, emb_cos], axis=1)
-
-
 class ImageEncoder(nn.Module):
     def __init__(
         self,
@@ -88,7 +67,7 @@ class ImageEncoder(nn.Module):
            ]
        )
 
-    def forward(self, image, mask=None, value_range=(-1, 1)
+    def forward(self, image, mask=None, value_range=(-1, 1)):
        if value_range is not None:
            low, high = value_range
            image = (image - low) / (high - low)
@@ -103,7 +82,7 @@ class ImageEncoder(nn.Module):
 
        return last_hidden_state
 
-    def unconditional_embedding(self, batch_size
+    def unconditional_embedding(self, batch_size):
        device = next(self.model.parameters()).device
        dtype = next(self.model.parameters()).dtype
        zero = torch.zeros(
@@ -131,82 +110,11 @@ class DinoImageEncoder(ImageEncoder):
    std = [0.229, 0.224, 0.225]
 
 
-class DinoImageEncoderMV(DinoImageEncoder):
-    def __init__(
-        self,
-        version=None,
-        config=None,
-        use_cls_token=True,
-        image_size=224,
-        view_num=4,
-        **kwargs,
-    ):
-        super().__init__(version, config, use_cls_token, image_size, **kwargs)
-        self.view_num = view_num
-        self.num_patches = self.num_patches
-        pos = np.arange(self.view_num, dtype=np.float32)
-        view_embedding = torch.from_numpy(
-            get_1d_sincos_pos_embed_from_grid(self.model.config.hidden_size, pos)).float()
-
-        view_embedding = view_embedding.unsqueeze(1).repeat(1, self.num_patches, 1)
-        self.view_embed = view_embedding.unsqueeze(0)
-
-    def forward(self, image, mask=None, value_range=(-1, 1), view_idxs=None):
-        if value_range is not None:
-            low, high = value_range
-            image = (image - low) / (high - low)
-
-        image = image.to(self.model.device, dtype=self.model.dtype)
-
-        bs, num_views, c, h, w = image.shape
-        image = image.view(bs * num_views, c, h, w)
-
-        inputs = self.transform(image)
-        outputs = self.model(inputs)
-
-        last_hidden_state = outputs.last_hidden_state
-        last_hidden_state = last_hidden_state.view(
-            bs, num_views, last_hidden_state.shape[-2],
-            last_hidden_state.shape[-1]
-        )
-
-        view_embedding = self.view_embed.to(last_hidden_state.dtype).to(last_hidden_state.device)
-        if view_idxs is not None:
-            assert len(view_idxs) == bs
-            view_embeddings = []
-            for i in range(bs):
-                view_idx = view_idxs[i]
-                assert num_views == len(view_idx)
-                view_embeddings.append(self.view_embed[:, view_idx, ...])
-            view_embedding = torch.cat(view_embeddings, 0).to(last_hidden_state.dtype).to(last_hidden_state.device)
-
-        if num_views != self.view_num:
-            view_embedding = view_embedding[:, :num_views, ...]
-        last_hidden_state = last_hidden_state + view_embedding
-        last_hidden_state = last_hidden_state.view(bs, num_views * last_hidden_state.shape[-2],
-                                                   last_hidden_state.shape[-1])
-        return last_hidden_state
-
-    def unconditional_embedding(self, batch_size, view_idxs=None, **kwargs):
-        device = next(self.model.parameters()).device
-        dtype = next(self.model.parameters()).dtype
-        zero = torch.zeros(
-            batch_size,
-            self.num_patches * len(view_idxs[0]),
-            self.model.config.hidden_size,
-            device=device,
-            dtype=dtype,
-        )
-        return zero
-
-
 def build_image_encoder(config):
    if config['type'] == 'CLIPImageEncoder':
        return CLIPImageEncoder(**config['kwargs'])
    elif config['type'] == 'DinoImageEncoder':
        return DinoImageEncoder(**config['kwargs'])
-    elif config['type'] == 'DinoImageEncoderMV':
-        return DinoImageEncoderMV(**config['kwargs'])
    else:
        raise ValueError(f'Unknown image encoder type: {config["type"]}')
 
@@ -221,17 +129,17 @@ class DualImageEncoder(nn.Module):
        self.main_image_encoder = build_image_encoder(main_image_encoder)
        self.additional_image_encoder = build_image_encoder(additional_image_encoder)
 
-    def forward(self, image, mask=None
+    def forward(self, image, mask=None):
        outputs = {
-            'main': self.main_image_encoder(image, mask=mask
-            'additional': self.additional_image_encoder(image, mask=mask
+            'main': self.main_image_encoder(image, mask=mask),
+            'additional': self.additional_image_encoder(image, mask=mask),
        }
        return outputs
 
-    def unconditional_embedding(self, batch_size
+    def unconditional_embedding(self, batch_size):
        outputs = {
-            'main': self.main_image_encoder.unconditional_embedding(batch_size
-            'additional': self.additional_image_encoder.unconditional_embedding(batch_size
+            'main': self.main_image_encoder.unconditional_embedding(batch_size),
+            'additional': self.additional_image_encoder.unconditional_embedding(batch_size),
        }
        return outputs
 
@@ -244,14 +152,14 @@ class SingleImageEncoder(nn.Module):
        super().__init__()
        self.main_image_encoder = build_image_encoder(main_image_encoder)
 
-    def forward(self, image, mask=None
+    def forward(self, image, mask=None):
        outputs = {
-            'main': self.main_image_encoder(image, mask=mask
+            'main': self.main_image_encoder(image, mask=mask),
        }
        return outputs
 
-    def unconditional_embedding(self, batch_size
+    def unconditional_embedding(self, batch_size):
        outputs = {
-            'main': self.main_image_encoder.unconditional_embedding(batch_size
+            'main': self.main_image_encoder.unconditional_embedding(batch_size),
        }
        return outputs
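Note: the helper removed above builds standard 1-D sin/cos positional embeddings; the deleted DinoImageEncoderMV used it to tag each input view. Reproduced verbatim below with a quick shape check; the 4 views and 1024-dim hidden size are illustrative numbers, not values from the repo.

import numpy as np

def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    # One sin/cos embedding row per position, total width embed_dim.
    assert embed_dim % 2 == 0
    omega = np.arange(embed_dim // 2, dtype=np.float64)
    omega /= embed_dim / 2.
    omega = 1. / 10000 ** omega            # (D/2,)
    pos = pos.reshape(-1)                  # (M,)
    out = np.einsum('m,d->md', pos, omega) # (M, D/2), outer product
    return np.concatenate([np.sin(out), np.cos(out)], axis=1)

emb = get_1d_sincos_pos_embed_from_grid(1024, np.arange(4, dtype=np.float32))
print(emb.shape)  # (4, 1024): one embedding per view index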
hy3dgen/shapegen/models/denoisers/__init__.py
CHANGED
@@ -1 +1,2 @@
 from .hunyuan3ddit import Hunyuan3DDiT
+from .hunyuandit import HunYuanDiTPlain
hy3dgen/shapegen/models/denoisers/hunyuan3ddit.py
CHANGED
@@ -70,15 +70,6 @@ def timestep_embedding(t: Tensor, dim, max_period=10000, time_factor: float = 10
     return embedding
 
 
-class GELU(nn.Module):
-    def __init__(self, approximate='tanh'):
-        super().__init__()
-        self.approximate = approximate
-
-    def forward(self, x: Tensor) -> Tensor:
-        return nn.functional.gelu(x.contiguous(), approximate=self.approximate)
-
-
 class MLPEmbedder(nn.Module):
     def __init__(self, in_dim: int, hidden_dim: int):
         super().__init__()
@@ -181,7 +172,7 @@ class DoubleStreamBlock(nn.Module):
         self.img_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.img_mlp = nn.Sequential(
             nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
-            GELU(approximate="tanh"),
+            nn.GELU(approximate="tanh"),
             nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
         )
 
@@ -192,7 +183,7 @@ class DoubleStreamBlock(nn.Module):
         self.txt_norm2 = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
         self.txt_mlp = nn.Sequential(
             nn.Linear(hidden_size, mlp_hidden_dim, bias=True),
-            GELU(approximate="tanh"),
+            nn.GELU(approximate="tanh"),
             nn.Linear(mlp_hidden_dim, hidden_size, bias=True),
         )
 
@@ -258,7 +249,7 @@ class SingleStreamBlock(nn.Module):
         self.hidden_size = hidden_size
         self.pre_norm = nn.LayerNorm(hidden_size, elementwise_affine=False, eps=1e-6)
 
-        self.mlp_act = GELU(approximate="tanh")
+        self.mlp_act = nn.GELU(approximate="tanh")
         self.modulation = Modulation(hidden_size, double=False)
 
     def forward(self, x: Tensor, vec: Tensor, pe: Tensor) -> Tensor:
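Note: the removed GELU class was a thin wrapper around torch.nn.functional.gelu, and the diff swaps it for the built-in nn.GELU(approximate="tanh"). A quick numerical check that the two modules agree:

import torch
import torch.nn as nn

class OldGELU(nn.Module):
    # The wrapper deleted by this commit, reproduced for comparison.
    def __init__(self, approximate='tanh'):
        super().__init__()
        self.approximate = approximate

    def forward(self, x):
        return nn.functional.gelu(x.contiguous(), approximate=self.approximate)

x = torch.randn(4, 16)
print(torch.allclose(OldGELU()(x), nn.GELU(approximate="tanh")(x)))  # True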
hy3dgen/shapegen/pipelines.py
CHANGED
@@ -34,12 +34,11 @@ import trimesh
 import yaml
 from PIL import Image
 from diffusers.utils.torch_utils import randn_tensor
-from diffusers.utils.import_utils import is_accelerate_version, is_accelerate_available
 from tqdm import tqdm
 
 from .models.autoencoders import ShapeVAE
 from .models.autoencoders import SurfaceExtractors
-from .utils import logger, synchronize_timer
+from .utils import logger, synchronize_timer
 
 
 def retrieve_timesteps(
@@ -138,9 +137,6 @@ def instantiate_from_config(config, **kwargs):
 
 
 class Hunyuan3DDiTPipeline:
-    model_cpu_offload_seq = "conditioner->model->vae"
-    _exclude_from_cpu_offload = []
-
     @classmethod
     @synchronize_timer('Hunyuan3DDiTPipeline Model Loading')
     def from_single_file(
@@ -221,12 +217,34 @@
             dtype=dtype,
             device=device,
         )
-
-
-
-
-
-        )
+        original_model_path = model_path
+        # try local path
+        base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
+        model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
+        logger.info(f'Try to load model from local path: {model_path}')
+        if not os.path.exists(model_path):
+            logger.info('Model path not exists, try to download from huggingface')
+            try:
+                import huggingface_hub
+                # download from huggingface
+                path = huggingface_hub.snapshot_download(repo_id=original_model_path)
+                model_path = os.path.join(path, subfolder)
+            except ImportError:
+                logger.warning(
+                    "You need to install HuggingFace Hub to load models from the hub."
+                )
+                raise RuntimeError(f"Model path {model_path} not found")
+            except Exception as e:
+                raise e
+
+        if not os.path.exists(model_path):
+            raise FileNotFoundError(f"Model path {original_model_path} not found")
+
+        extension = 'ckpt' if not use_safetensors else 'safetensors'
+        variant = '' if variant is None else f'.{variant}'
+        ckpt_name = f'model{variant}.{extension}'
+        config_path = os.path.join(model_path, 'config.yaml')
+        ckpt_path = os.path.join(model_path, ckpt_name)
         return cls.from_single_file(
             ckpt_path,
             config_path,
@@ -271,18 +289,12 @@
         if enabled:
             model_path = self.kwargs['from_pretrained_kwargs']['model_path']
             turbo_vae_mapping = {
-                'Hunyuan3D-2':
-                'Hunyuan3D-
-                'Hunyuan3D-2mini': ('tencent/Hunyuan3D-2mini', 'hunyuan3d-vae-v2-mini-turbo'),
+                'Hunyuan3D-2': 'hunyuan3d-vae-v2-0-turbo',
+                'Hunyuan3D-2s': 'hunyuan3d-vae-v2-s-turbo'
             }
             model_name = model_path.split('/')[-1]
             if replace_vae and model_name in turbo_vae_mapping:
-                model_path, subfolder
-                self.vae = ShapeVAE.from_pretrained(
-                    model_path, subfolder=subfolder,
-                    use_safetensors=self.kwargs['from_pretrained_kwargs']['use_safetensors'],
-                    device=self.device,
-                )
+                self.vae = ShapeVAE.from_pretrained(model_path, subfolder=turbo_vae_mapping[model_name])
             self.vae.enable_flashvdm_decoder(
                 enabled=enabled,
                 adaptive_kv_selection=adaptive_kv_selection,
@@ -292,146 +304,33 @@
         else:
             model_path = self.kwargs['from_pretrained_kwargs']['model_path']
             vae_mapping = {
-                'Hunyuan3D-2':
-                'Hunyuan3D-
-                'Hunyuan3D-2mini': ('tencent/Hunyuan3D-2mini', 'hunyuan3d-vae-v2-mini'),
+                'Hunyuan3D-2': 'hunyuan3d-vae-v2-0',
+                'Hunyuan3D-2s': 'hunyuan3d-vae-v2-s'
             }
             model_name = model_path.split('/')[-1]
             if model_name in vae_mapping:
-                model_path, subfolder
-                self.vae = ShapeVAE.from_pretrained(model_path, subfolder=subfolder)
+                self.vae = ShapeVAE.from_pretrained(model_path, subfolder=vae_mapping[model_name])
             self.vae.enable_flashvdm_decoder(enabled=False)
 
     def to(self, device=None, dtype=None):
-        if dtype is not None:
-            self.dtype = dtype
-            self.vae.to(dtype=dtype)
-            self.model.to(dtype=dtype)
-            self.conditioner.to(dtype=dtype)
         if device is not None:
             self.device = torch.device(device)
             self.vae.to(device)
             self.model.to(device)
             self.conditioner.to(device)
-
-
-
-
-
-        [`~DiffusionPipeline.enable_sequential_cpu_offload`] the execution device can only be inferred from
-        Accelerate's module hooks.
-        """
-        for name, model in self.components.items():
-            if not isinstance(model, torch.nn.Module) or name in self._exclude_from_cpu_offload:
-                continue
-
-            if not hasattr(model, "_hf_hook"):
-                return self.device
-            for module in model.modules():
-                if (
-                    hasattr(module, "_hf_hook")
-                    and hasattr(module._hf_hook, "execution_device")
-                    and module._hf_hook.execution_device is not None
-                ):
-                    return torch.device(module._hf_hook.execution_device)
-        return self.device
-
-    def enable_model_cpu_offload(self, gpu_id: Optional[int] = None, device: Union[torch.device, str] = "cuda"):
-        r"""
-        Offloads all models to CPU using accelerate, reducing memory usage with a low impact on performance. Compared
-        to `enable_sequential_cpu_offload`, this method moves one whole model at a time to the GPU when its `forward`
-        method is called, and the model remains in GPU until the next model runs. Memory savings are lower than with
-        `enable_sequential_cpu_offload`, but performance is much better due to the iterative execution of the `unet`.
-
-        Arguments:
-            gpu_id (`int`, *optional*):
-                The ID of the accelerator that shall be used in inference. If not specified, it will default to 0.
-            device (`torch.Device` or `str`, *optional*, defaults to "cuda"):
-                The PyTorch device type of the accelerator that shall be used in inference. If not specified, it will
-                default to "cuda".
-        """
-        if self.model_cpu_offload_seq is None:
-            raise ValueError(
-                "Model CPU offload cannot be enabled because no `model_cpu_offload_seq` class attribute is set."
-            )
-
-        if is_accelerate_available() and is_accelerate_version(">=", "0.17.0.dev0"):
-            from accelerate import cpu_offload_with_hook
-        else:
-            raise ImportError("`enable_model_cpu_offload` requires `accelerate v0.17.0` or higher.")
-
-        torch_device = torch.device(device)
-        device_index = torch_device.index
-
-        if gpu_id is not None and device_index is not None:
-            raise ValueError(
-                f"You have passed both `gpu_id`={gpu_id} and an index as part of the passed device `device`={device}"
-                f"Cannot pass both. Please make sure to either not define `gpu_id` or not pass the index as part of the device: `device`={torch_device.type}"
-            )
-
-        # _offload_gpu_id should be set to passed gpu_id (or id in passed `device`) or default to previously set id or default to 0
-        self._offload_gpu_id = gpu_id or torch_device.index or getattr(self, "_offload_gpu_id", 0)
-
-        device_type = torch_device.type
-        device = torch.device(f"{device_type}:{self._offload_gpu_id}")
-
-        if self.device.type != "cpu":
-            self.to("cpu")
-            device_mod = getattr(torch, self.device.type, None)
-            if hasattr(device_mod, "empty_cache") and device_mod.is_available():
-                device_mod.empty_cache()  # otherwise we don't see the memory savings (but they probably exist)
-
-        all_model_components = {k: v for k, v in self.components.items() if isinstance(v, torch.nn.Module)}
-
-        self._all_hooks = []
-        hook = None
-        for model_str in self.model_cpu_offload_seq.split("->"):
-            model = all_model_components.pop(model_str, None)
-            if not isinstance(model, torch.nn.Module):
-                continue
-
-            _, hook = cpu_offload_with_hook(model, device, prev_module_hook=hook)
-            self._all_hooks.append(hook)
-
-        # CPU offload models that are not in the seq chain unless they are explicitly excluded
-        # these models will stay on CPU until maybe_free_model_hooks is called
-        # some models cannot be in the seq chain because they are iteratively called, such as controlnet
-        for name, model in all_model_components.items():
-            if not isinstance(model, torch.nn.Module):
-                continue
-
-            if name in self._exclude_from_cpu_offload:
-                model.to(device)
-            else:
-                _, hook = cpu_offload_with_hook(model, device)
-                self._all_hooks.append(hook)
-
-    def maybe_free_model_hooks(self):
-        r"""
-        Function that offloads all components, removes all model hooks that were added when using
-        `enable_model_cpu_offload` and then applies them again. In case the model has not been offloaded this function
-        is a no-op. Make sure to add this function to the end of the `__call__` function of your pipeline so that it
-        functions correctly when applying enable_model_cpu_offload.
-        """
-        if not hasattr(self, "_all_hooks") or len(self._all_hooks) == 0:
-            # `enable_model_cpu_offload` has not be called, so silently do nothing
-            return
-
-        for hook in self._all_hooks:
-            # offload model and remove hook from model
-            hook.offload()
-            hook.remove()
-
-        # make sure the model is in the same state as before calling it
-        self.enable_model_cpu_offload()
+        if dtype is not None:
+            self.dtype = dtype
+            self.vae.to(dtype=dtype)
+            self.model.to(dtype=dtype)
+            self.conditioner.to(dtype=dtype)
 
     @synchronize_timer('Encode cond')
-    def encode_cond(self, image,
+    def encode_cond(self, image, mask, do_classifier_free_guidance, dual_guidance):
         bsz = image.shape[0]
-        cond = self.conditioner(image=image,
+        cond = self.conditioner(image=image, mask=mask)
 
         if do_classifier_free_guidance:
-            un_cond = self.conditioner.unconditional_embedding(bsz
+            un_cond = self.conditioner.unconditional_embedding(bsz)
 
             if dual_guidance:
                 un_cond_drop_main = copy.deepcopy(un_cond)
@@ -447,7 +346,7 @@
 
                 cond = cat_recursive(cond, un_cond_drop_main, un_cond)
             else:
-                un_cond = self.conditioner.unconditional_embedding(bsz
+                un_cond = self.conditioner.unconditional_embedding(bsz)
 
                 def cat_recursive(a, b):
                     if isinstance(a, torch.Tensor):
@@ -494,27 +393,25 @@
         latents = latents * getattr(self.scheduler, 'init_noise_sigma', 1.0)
         return latents
 
-    def prepare_image(self, image)
+    def prepare_image(self, image):
        if isinstance(image, str) and not os.path.exists(image):
            raise FileNotFoundError(f"Couldn't find image at path {image}")
 
        if not isinstance(image, list):
            image = [image]
-
-
+        image_pts = []
+        mask_pts = []
        for img in image:
-
-
-
-        cond_input = {k: [] for k in outputs[0].keys()}
-        for output in outputs:
-            for key, value in output.items():
-                cond_input[key].append(value)
-        for key, value in cond_input.items():
-            if isinstance(value[0], torch.Tensor):
-                cond_input[key] = torch.cat(value, dim=0)
+            image_pt, mask_pt = self.image_processor(img, return_mask=True)
+            image_pts.append(image_pt)
+            mask_pts.append(mask_pt)
 
-
+        image_pts = torch.cat(image_pts, dim=0).to(self.device, dtype=self.dtype)
+        if mask_pts[0] is not None:
+            mask_pts = torch.cat(mask_pts, dim=0).to(self.device, dtype=self.dtype)
+        else:
+            mask_pts = None
+        return image_pts, mask_pts
 
    def get_guidance_scale_embedding(self, w, embedding_dim=512, dtype=torch.float32):
        """
@@ -589,6 +486,7 @@
 
        image, mask = self.prepare_image(image)
        cond = self.encode_cond(image=image,
+                                mask=mask,
                                do_classifier_free_guidance=do_classifier_free_guidance,
                                dual_guidance=dual_guidance)
        batch_size = image.shape[0]
@@ -648,17 +546,7 @@
            box_v, mc_level, num_chunks, octree_resolution, mc_algo,
        )
 
-    def _export(
-        self,
-        latents,
-        output_type='trimesh',
-        box_v=1.01,
-        mc_level=0.0,
-        num_chunks=20000,
-        octree_resolution=256,
-        mc_algo='mc',
-        enable_pbar=True
-    ):
+    def _export(self, latents, output_type, box_v, mc_level, num_chunks, octree_resolution, mc_algo, enable_pbar=True):
        if not output_type == "latent":
            latents = 1. / self.vae.scale_factor * latents
            latents = self.vae(latents)
@@ -685,7 +573,7 @@ class Hunyuan3DDiTFlowMatchingPipeline(Hunyuan3DDiTPipeline):
    @torch.inference_mode()
    def __call__(
        self,
-        image: Union[str, List[str], Image.Image
+        image: Union[str, List[str], Image.Image] = None,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        sigmas: List[float] = None,
@@ -713,11 +601,10 @@
            self.model.guidance_embed is True
        )
 
-
-        image = cond_inputs.pop('image')
+        image, mask = self.prepare_image(image)
        cond = self.encode_cond(
            image=image,
-
+            mask=mask,
            do_classifier_free_guidance=do_classifier_free_guidance,
            dual_guidance=False,
        )
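Note: the new from_pretrained path above first looks for the checkpoint under $HY3DGEN_MODELS (default ~/.cache/hy3dgen) and only then falls back to a Hugging Face Hub snapshot. A standalone sketch of that lookup order; the repo id and subfolder in the commented call are illustrative, not values asserted by this diff.

import os

def resolve_model_dir(model_path: str, subfolder: str) -> str:
    # Local cache first, as in the diff.
    base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
    local_dir = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
    if os.path.exists(local_dir):
        return local_dir
    # Optional dependency: only needed when the checkpoint is not cached locally.
    import huggingface_hub
    snapshot = huggingface_hub.snapshot_download(repo_id=model_path)
    return os.path.join(snapshot, subfolder)

# Example (requires network access on a cache miss):
# resolve_model_dir('tencent/Hunyuan3D-2', 'hunyuan3d-dit-v2-0')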
hy3dgen/shapegen/postprocessors.py
CHANGED
@@ -22,16 +22,13 @@
22 | # fine-tuning enabling code and other elements of the foregoing made publicly available
23 | # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24 |
25 | - import os
26 | import tempfile
27 | from typing import Union
28 |
29 | - import numpy as np
30 | import pymeshlab
31 | - import torch
32 | import trimesh
33 |
34 | - from .models.
35 | from .utils import synchronize_timer
36 |
37 |
@@ -165,62 +162,3 @@ class DegenerateFaceRemover:
165 |
166 | mesh = export_mesh(mesh, ms)
167 | return mesh
168 | -
169 | -
170 | - def import_pymeshlab_mesh(mesh: Union[pymeshlab.MeshSet, trimesh.Trimesh, Latent2MeshOutput, str]) -> pymeshlab.MeshSet:
171 | - if isinstance(mesh, str):
172 | - mesh = load_mesh(mesh)
173 | - elif isinstance(mesh, Latent2MeshOutput):
174 | - mesh = pymeshlab.MeshSet()
175 | - mesh_pymeshlab = pymeshlab.Mesh(vertex_matrix=mesh.mesh_v, face_matrix=mesh.mesh_f)
176 | - mesh.add_mesh(mesh_pymeshlab, "converted_mesh")
177 | -
178 | - if isinstance(mesh, (trimesh.Trimesh, trimesh.scene.Scene)):
179 | - mesh = trimesh2pymeshlab(mesh)
180 | -
181 | - return mesh
182 | -
183 | -
184 | - def mesh_normalize(mesh):
185 | - """
186 | - Normalize mesh vertices to sphere
187 | - """
188 | - scale_factor = 1.2
189 | - vtx_pos = np.asarray(mesh.vertices)
190 | - max_bb = (vtx_pos - 0).max(0)[0]
191 | - min_bb = (vtx_pos - 0).min(0)[0]
192 | -
193 | - center = (max_bb + min_bb) / 2
194 | -
195 | - scale = torch.norm(torch.tensor(vtx_pos - center, dtype=torch.float32), dim=1).max() * 2.0
196 | -
197 | - vtx_pos = (vtx_pos - center) * (scale_factor / float(scale))
198 | - mesh.vertices = vtx_pos
199 | -
200 | - return mesh
201 | -
202 | -
203 | - class MeshSimplifier:
204 | - def __init__(self, executable: str = None):
205 | - if executable is None:
206 | - CURRENT_DIR = os.path.dirname(os.path.abspath(__file__))
207 | - executable = os.path.join(CURRENT_DIR, "mesh_simplifier.bin")
208 | - self.executable = executable
209 | -
210 | - @synchronize_timer('MeshSimplifier')
211 | - def __call__(
212 | - self,
213 | - mesh: Union[trimesh.Trimesh],
214 | - ) -> Union[trimesh.Trimesh]:
215 | - with tempfile.NamedTemporaryFile(suffix='.obj', delete=True) as temp_input:
216 | - with tempfile.NamedTemporaryFile(suffix='.obj', delete=True) as temp_output:
217 | - mesh.export(temp_input.name)
218 | - os.system(f'{self.executable} {temp_input.name} {temp_output.name}')
219 | - ms = trimesh.load(temp_output.name, process=False)
220 | - if isinstance(ms, trimesh.Scene):
221 | - combined_mesh = trimesh.Trimesh()
222 | - for geom in ms.geometry.values():
223 | - combined_mesh = trimesh.util.concatenate([combined_mesh, geom])
224 | - ms = combined_mesh
225 | - ms = mesh_normalize(ms)
226 | - return ms

22 | # fine-tuning enabling code and other elements of the foregoing made publicly available
23 | # by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.
24 |
25 | import tempfile
26 | from typing import Union
27 |
28 | import pymeshlab
29 | import trimesh
30 |
31 | + from .models.vae import Latent2MeshOutput
32 | from .utils import synchronize_timer
33 |
34 |

162 |
163 | mesh = export_mesh(mesh, ms)
164 | return mesh
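Note: this removes import_pymeshlab_mesh, mesh_normalize, and the external-binary MeshSimplifier, which is also why the numpy and torch imports are dropped. If the sphere normalization is still needed downstream, a numpy-only sketch of an equivalent helper (hypothetical, not part of this upload) is:

import numpy as np
import trimesh

def normalize_to_sphere(mesh: trimesh.Trimesh, scale_factor: float = 1.2) -> trimesh.Trimesh:
    # Center vertices on the bounding-box midpoint, then scale so the farthest
    # vertex fits inside a sphere of diameter `scale_factor`. (Centering is
    # per-axis here; the removed helper derived a scalar center from the
    # first coordinate only.)
    vtx = np.asarray(mesh.vertices, dtype=np.float64)
    center = (vtx.max(axis=0) + vtx.min(axis=0)) / 2.0
    scale = np.linalg.norm(vtx - center, axis=1).max() * 2.0
    mesh.vertices = (vtx - center) * (scale_factor / float(scale))
    return mesh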
hy3dgen/shapegen/preprocessors.py
CHANGED
@@ -96,7 +96,7 @@ class ImageProcessorV2:
96 | mask = mask.clip(0, 255).astype(np.uint8)
97 | return result, mask
98 |
99 | - def __call__(self, image, border_ratio=0.15, to_tensor=True, **kwargs):
100 | if self.border_ratio is not None:
101 | border_ratio = self.border_ratio
102 | if isinstance(image, str):
@@ -115,74 +115,13 @@ class ImageProcessorV2:
115 | if to_tensor:
116 | image = array_to_tensor(image)
117 | mask = array_to_tensor(mask)
118 | -
119 | -
120 | -
121 | - 'mask': mask
122 | - }
123 | - return outputs
124 | -
125 | -
126 | - class MVImageProcessorV2(ImageProcessorV2):
127 | - """
128 | - view order: front, front clockwise 90, back, front clockwise 270
129 | - """
130 | - return_view_idx = True
131 | -
132 | - def __init__(self, size=512, border_ratio=None):
133 | - super().__init__(size, border_ratio)
134 | - self.view2idx = {
135 | - 'front': 0,
136 | - 'left': 1,
137 | - 'back': 2,
138 | - 'right': 3
139 | - }
140 | -
141 | - def __call__(self, image_dict, border_ratio=0.15, to_tensor=True, **kwargs):
142 | - if self.border_ratio is not None:
143 | - border_ratio = self.border_ratio
144 | -
145 | - images = []
146 | - masks = []
147 | - view_idxs = []
148 | - for idx, (view_tag, image) in enumerate(image_dict.items()):
149 | - view_idxs.append(self.view2idx[view_tag])
150 | - if isinstance(image, str):
151 | - image = cv2.imread(image, cv2.IMREAD_UNCHANGED)
152 | - image, mask = self.recenter(image, border_ratio=border_ratio)
153 | - image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
154 | - elif isinstance(image, Image.Image):
155 | - image = image.convert("RGBA")
156 | - image = np.asarray(image)
157 | - image, mask = self.recenter(image, border_ratio=border_ratio)
158 | -
159 | - image = cv2.resize(image, (self.size, self.size), interpolation=cv2.INTER_CUBIC)
160 | - mask = cv2.resize(mask, (self.size, self.size), interpolation=cv2.INTER_NEAREST)
161 | - mask = mask[..., np.newaxis]
162 | -
163 | - if to_tensor:
164 | - image = array_to_tensor(image)
165 | - mask = array_to_tensor(mask)
166 | - images.append(image)
167 | - masks.append(mask)
168 | -
169 | - zipped_lists = zip(view_idxs, images, masks)
170 | - sorted_zipped_lists = sorted(zipped_lists)
171 | - view_idxs, images, masks = zip(*sorted_zipped_lists)
172 | -
173 | - image = torch.cat(images, 0).unsqueeze(0)
174 | - mask = torch.cat(masks, 0).unsqueeze(0)
175 | - outputs = {
176 | - 'image': image,
177 | - 'mask': mask,
178 | - 'view_idxs': view_idxs
179 | - }
180 | - return outputs
181 |
182 |
183 | IMAGE_PROCESSORS = {
184 | "v2": ImageProcessorV2,
185 | - 'mv_v2': MVImageProcessorV2,
186 | }
187 |
188 | DEFAULT_IMAGEPROCESSOR = 'v2'

96 | mask = mask.clip(0, 255).astype(np.uint8)
97 | return result, mask
98 |
99 | + def __call__(self, image, border_ratio=0.15, to_tensor=True, return_mask=False, **kwargs):
100 | if self.border_ratio is not None:
101 | border_ratio = self.border_ratio
102 | if isinstance(image, str):

115 | if to_tensor:
116 | image = array_to_tensor(image)
117 | mask = array_to_tensor(mask)
118 | + if return_mask:
119 | + return image, mask
120 | + return image
121 |
122 |
123 | IMAGE_PROCESSORS = {
124 | "v2": ImageProcessorV2,
125 | }
126 |
127 | DEFAULT_IMAGEPROCESSOR = 'v2'
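Note: ImageProcessorV2.__call__ now returns only the image tensor by default, and the (image, mask) pair when return_mask=True, which is how the pipeline's prepare_image calls it. A brief usage sketch (constructor arguments and asset path are assumptions, not shown in this diff):

from hy3dgen.shapegen.preprocessors import ImageProcessorV2

processor = ImageProcessorV2(size=512)                          # args assumed from this file
image = processor('assets/demo.png')                            # default: image tensor only
image, mask = processor('assets/demo.png', return_mask=True)    # as used by prepare_image()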
hy3dgen/shapegen/utils.py
CHANGED
@@ -70,40 +70,3 @@ class synchronize_timer:
70 | return result
71 |
72 | return wrapper
73 | -
74 | -
75 | - def smart_load_model(
76 | - model_path,
77 | - subfolder,
78 | - use_safetensors,
79 | - variant,
80 | - ):
81 | - original_model_path = model_path
82 | - # try local path
83 | - base_dir = os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen')
84 | - model_path = os.path.expanduser(os.path.join(base_dir, model_path, subfolder))
85 | - logger.info(f'Try to load model from local path: {model_path}')
86 | - if not os.path.exists(model_path):
87 | - logger.info('Model path not exists, try to download from huggingface')
88 | - try:
89 | - import huggingface_hub
90 | - # download from huggingface
91 | - path = huggingface_hub.snapshot_download(repo_id=original_model_path)
92 | - model_path = os.path.join(path, subfolder)
93 | - except ImportError:
94 | - logger.warning(
95 | - "You need to install HuggingFace Hub to load models from the hub."
96 | - )
97 | - raise RuntimeError(f"Model path {model_path} not found")
98 | - except Exception as e:
99 | - raise e
100 | -
101 | - if not os.path.exists(model_path):
102 | - raise FileNotFoundError(f"Model path {original_model_path} not found")
103 | -
104 | - extension = 'ckpt' if not use_safetensors else 'safetensors'
105 | - variant = '' if variant is None else f'.{variant}'
106 | - ckpt_name = f'model{variant}.{extension}'
107 | - config_path = os.path.join(model_path, 'config.yaml')
108 | - ckpt_path = os.path.join(model_path, ckpt_name)
109 | - return config_path, ckpt_path

70 | return result
71 |
72 | return wrapper
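Note: smart_load_model is removed from utils.py. Callers that still need to resolve a local-cache-or-Hub checkpoint can do so directly. The sketch below follows the removed helper's assumptions (HY3DGEN_MODELS cache directory, config.yaml plus a single weights file per subfolder) and is not part of this upload.

import os
from huggingface_hub import snapshot_download

def resolve_checkpoint(repo_id: str, subfolder: str, use_safetensors: bool = True, variant: str = None):
    # Look in the local cache first, mirroring the removed helper's behaviour.
    base_dir = os.path.expanduser(os.environ.get('HY3DGEN_MODELS', '~/.cache/hy3dgen'))
    model_path = os.path.join(base_dir, repo_id, subfolder)
    if not os.path.exists(model_path):
        # Fall back to downloading the full snapshot from the Hugging Face Hub.
        model_path = os.path.join(snapshot_download(repo_id=repo_id), subfolder)
    extension = 'safetensors' if use_safetensors else 'ckpt'
    suffix = '' if variant is None else f'.{variant}'
    config_path = os.path.join(model_path, 'config.yaml')
    ckpt_path = os.path.join(model_path, f'model{suffix}.{extension}')
    return config_path, ckpt_path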