# Hunyuan 3D is licensed under the TENCENT HUNYUAN NON-COMMERCIAL LICENSE AGREEMENT
# except for the third-party components listed below.
# Hunyuan 3D does not impose any additional limitations beyond what is outlined
# in the repsective licenses of these third-party components.
# Users must comply with all terms and conditions of original licenses of these third-party
# components and must ensure that the usage of the third party components adheres to
# all relevant laws and regulations.

# For avoidance of doubts, Hunyuan 3D means the large language models and
# their software and algorithms, including trained model weights, parameters (including
# optimizer states), machine-learning model code, inference-enabling code, training-enabling code,
# fine-tuning enabling code and other elements of the foregoing made publicly available
# by Tencent in accordance with TENCENT HUNYUAN COMMUNITY LICENSE AGREEMENT.

from typing import Union, Tuple, List, Callable

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from einops import repeat
from tqdm import tqdm

from .attention_blocks import CrossAttentionDecoder
from .attention_processors import FlashVDMCrossAttentionProcessor, FlashVDMTopMCrossAttentionProcessor
from ...utils import logger


def extract_near_surface_volume_fn(input_tensor: torch.Tensor, alpha: float):
    """Flag voxels whose sign differs from at least one of their six face neighbours.

    Args:
        input_tensor: 3-D tensor of field values (the padding and slicing
            below are written for exactly three dimensions).
        alpha: Offset added to every value before the sign comparison.

    Returns:
        An int32 tensor shaped like ``input_tensor`` holding 1 where the
        offset value has a different sign than any face neighbour and the
        voxel passes the validity test, and 0 elsewhere.
    """
    # Apply the offset. Values at or below -9000 are treated as invalid
    # (NOTE(review): -9000 is assumed to be the "no data" sentinel).
    val = input_tensor + alpha
    valid_mask = val > -9000

    def get_neighbor(t, shift, axis):
        """Return ``t`` shifted along tensor dimension ``axis``.

        Element ``i`` of the result is ``t[i + shift]`` along ``axis``; the
        border value is replicated where that index falls outside the
        tensor, so the output keeps the shape of ``t``.
        """
        if shift == 0:
            return t.clone()

        # F.pad lists pads starting from the LAST dimension:
        # (last_front, last_back, mid_front, mid_back, first_front, first_back).
        # A positive shift pads the back of `axis`, a negative one the front.
        pad = [0] * 6
        pad[2 * (2 - axis) + (1 if shift > 0 else 0)] = abs(shift)

        # Replicate padding expects batch and channel dimensions.
        padded = F.pad(t.unsqueeze(0).unsqueeze(0), pad, mode='replicate')
        padded = padded.squeeze(0).squeeze(0)

        # Crop back to the original extent on the side opposite the padding.
        window = [slice(None)] * 3
        window[axis] = slice(shift, None) if shift > 0 else slice(None, shift)
        # Index with a tuple: a list of slices is not a valid multi-dim index.
        return padded[tuple(window)]

    def safe_where(neighbor):
        """Replace invalid neighbour values with the centre value."""
        return torch.where(neighbor > -9000, neighbor, val)

    # Both neighbours (+1 / -1) along each of the three tensor dimensions.
    neighbors = [
        safe_where(get_neighbor(val, shift, axis=axis))
        for axis in range(3)
        for shift in (1, -1)
    ]

    # Compare signs in float32 regardless of the input dtype.
    sign = torch.sign(val.to(torch.float32))
    neighbors_sign = torch.stack(
        [torch.sign(n.to(torch.float32)) for n in neighbors], dim=0
    )

    # A voxel is "near surface" when any neighbour's sign disagrees with it.
    same_sign = torch.all(neighbors_sign == sign, dim=0)
    mask = (~same_sign).to(torch.int32)
    return mask * valid_mask.to(torch.int32)


def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    """Build a regular grid of 3-D sample points spanning a bounding box.

    Args:
        bbox_min: Minimum corner; elements 0..2 are read.
        bbox_max: Maximum corner; elements 0..2 are read.
        octree_resolution: Number of cells per axis; each axis gets
            ``octree_resolution + 1`` evenly spaced samples.
        indexing: Passed through to ``np.meshgrid``.

    Returns:
        ``(xyz, grid_size, length)`` where ``xyz`` is a float32 array with a
        trailing dimension of 3 holding the coordinates, ``grid_size`` is the
        list of samples per axis, and ``length`` is ``bbox_max - bbox_min``.
    """
    length = bbox_max - bbox_min
    num_points = int(octree_resolution) + 1
    x = np.linspace(bbox_min[0], bbox_max[0], num_points, dtype=np.float32)
    y = np.linspace(bbox_min[1], bbox_max[1], num_points, dtype=np.float32)
    z = np.linspace(bbox_min[2], bbox_max[2], num_points, dtype=np.float32)
    [xs, ys, zs] = np.meshgrid(x, y, z, indexing=indexing)
    xyz = np.stack((xs, ys, zs), axis=-1)
    grid_size = [num_points, num_points, num_points]

    return xyz, grid_size, length


def _bounds_to_bbox(bounds):
    """Split ``bounds`` into ``(bbox_min, bbox_max)`` numpy arrays.

    A scalar ``b`` (int or float) is expanded to the cube ``[-b, b]^3``;
    otherwise ``bounds`` is read as ``[xmin, ymin, zmin, xmax, ymax, zmax]``.
    """
    if isinstance(bounds, (int, float)):
        bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]
    return np.array(bounds[0:3]), np.array(bounds[3:6])


def _resolution_pyramid(octree_resolution, min_resolution):
    """Return coarse-to-fine resolutions obtained by repeated halving.

    Halving stops once the value drops below ``min_resolution``; if the
    requested resolution is already below it, it is returned on its own.
    """
    resolutions = []
    if octree_resolution < min_resolution:
        resolutions.append(octree_resolution)
    while octree_resolution >= min_resolution:
        resolutions.append(octree_resolution)
        octree_resolution = octree_resolution // 2
    resolutions.reverse()
    return resolutions


def _make_dilation_conv(device, dtype):
    """Create a bias-free 3x3x3 convolution with all-ones weights.

    Applied to a non-negative mask it yields a positive value wherever any
    voxel of the 3x3x3 neighbourhood is set, i.e. it dilates the mask.
    """
    conv = nn.Conv3d(1, 1, 3, padding=1, bias=False, device=device, dtype=dtype)
    conv.weight = torch.nn.Parameter(torch.ones(conv.weight.shape, dtype=dtype, device=device))
    return conv


class VanillaVolumeDecoder:
    """Decode latents on a dense regular grid, querying in fixed-size chunks."""

    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        octree_resolution: int = None,
        enable_pbar: bool = True,
        **kwargs,
    ):
        """Evaluate ``geo_decoder`` at every point of a dense grid.

        Args:
            latents: Latent tensor; its first dimension is the batch size and
                its device/dtype are used for the query points.
            geo_decoder: Callable invoked as
                ``geo_decoder(queries=..., latents=...)``.
            bounds: Scalar half-extent of a cube centred at the origin, or a
                6-sequence ``[xmin, ymin, zmin, xmax, ymax, zmax]``.
            num_chunks: Maximum number of query points per decoder call.
            octree_resolution: Cells per axis; must be provided (the default
                ``None`` is not usable).
            enable_pbar: Show the tqdm progress bar when True.
            **kwargs: Ignored.

        Returns:
            A float32 tensor of shape ``(batch, R + 1, R + 1, R + 1)`` with
            ``R = octree_resolution``.
        """
        device = latents.device
        dtype = latents.dtype
        batch_size = latents.shape[0]

        # 1. generate query points
        bbox_min, bbox_max = _bounds_to_bbox(bounds)
        xyz_samples, grid_size, _ = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=octree_resolution,
            indexing="ij"
        )
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)

        # 2. latents to 3d volume
        batch_logits = []
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks), desc="Volume Decoding",
                          disable=not enable_pbar):
            chunk_queries = xyz_samples[start: start + num_chunks, :]
            chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
            logits = geo_decoder(queries=chunk_queries, latents=latents)
            batch_logits.append(logits)

        grid_logits = torch.cat(batch_logits, dim=1)
        grid_logits = grid_logits.view((batch_size, *grid_size)).float()

        return grid_logits


class HierarchicalVolumeDecoding:
    """Coarse-to-fine decoding that only re-queries voxels near the surface."""

    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        mc_level: float = 0.0,
        octree_resolution: int = None,
        min_resolution: int = 63,
        enable_pbar: bool = True,
        **kwargs,
    ):
        """Decode a dense coarse grid, then refine near-surface voxels level by level.

        Args:
            latents: Latent tensor; first dimension is the batch size. The
                refinement stages squeeze the batch dimension and read only
                batch element 0, so a batch of one is effectively assumed.
            geo_decoder: Callable invoked as
                ``geo_decoder(queries=..., latents=...)``.
            bounds: Scalar half-extent of a cube centred at the origin, or a
                6-sequence ``[xmin, ymin, zmin, xmax, ymax, zmax]``.
            num_chunks: Maximum number of query points per decoder call.
            mc_level: Offset passed to ``extract_near_surface_volume_fn``.
            octree_resolution: Finest resolution (cells per axis); must be
                provided.
            min_resolution: Halving of the resolution stops below this value.
            enable_pbar: Show the tqdm progress bars when True.
            **kwargs: Ignored.

        Returns:
            Tensor of shape ``(1, R + 1, R + 1, R + 1)`` for the finest
            resolution ``R`` when refinement ran (voxels never queried are
            NaN), or the dense coarse grid when only one level exists.
        """
        device = latents.device
        dtype = latents.dtype

        resolutions = _resolution_pyramid(octree_resolution, min_resolution)

        # 1. generate query points
        bbox_min, bbox_max = _bounds_to_bbox(bounds)
        bbox_size = bbox_max - bbox_min

        xyz_samples, grid_size, _ = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=resolutions[0],
            indexing="ij"
        )

        dilate = _make_dilation_conv(device, dtype)

        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype).contiguous().reshape(-1, 3)

        # 2. latents to 3d volume (dense pass at the coarsest resolution)
        batch_logits = []
        batch_size = latents.shape[0]
        for start in tqdm(range(0, xyz_samples.shape[0], num_chunks),
                          desc=f"Hierarchical Volume Decoding [r{resolutions[0] + 1}]",
                          disable=not enable_pbar):
            queries = xyz_samples[start: start + num_chunks, :]
            batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
            logits = geo_decoder(queries=batch_queries, latents=latents)
            batch_logits.append(logits)

        grid_logits = torch.cat(batch_logits, dim=1).view((batch_size, *grid_size))

        # 3. sparse refinement, one level per remaining resolution
        for octree_depth_now in resolutions[1:]:
            grid_shape = (octree_depth_now + 1,) * 3
            voxel_size = bbox_size / octree_depth_now
            next_index = torch.zeros(grid_shape, dtype=dtype, device=device)
            # -10000 marks voxels that are never queried at this level.
            next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)

            # Candidate voxels: sign changes against a neighbour, plus any
            # voxel whose absolute value is below 0.95.
            curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
            curr_points += grid_logits.squeeze(0).abs() < 0.95

            # The last level dilates only on the fine grid; earlier levels
            # dilate once on the coarse grid and once on the fine grid.
            expand_num = 0 if octree_depth_now == resolutions[-1] else 1
            for _ in range(expand_num):
                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)

            # Coarse voxel (i, j, k) maps to fine voxel (2i, 2j, 2k).
            (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
            next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
            for _ in range(2 - expand_num):
                next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
            nidx = torch.where(next_index > 0)

            # Convert integer voxel indices to world coordinates. The indices
            # are int64, so the scale and offset must be built as floats:
            # casting them to the index dtype would truncate the fractional
            # voxel size to 0 and collapse every query onto a single point.
            next_points = torch.stack(nidx, dim=1)
            next_points = (next_points * torch.tensor(voxel_size, dtype=torch.float32, device=device) +
                           torch.tensor(bbox_min, dtype=torch.float32, device=device))

            batch_logits = []
            for start in tqdm(range(0, next_points.shape[0], num_chunks),
                              desc=f"Hierarchical Volume Decoding [r{octree_depth_now + 1}]",
                              disable=not enable_pbar):
                queries = next_points[start: start + num_chunks, :]
                batch_queries = repeat(queries, "p c -> b p c", b=batch_size)
                logits = geo_decoder(queries=batch_queries.to(latents.dtype), latents=latents)
                batch_logits.append(logits)
            grid_logits = torch.cat(batch_logits, dim=1)
            # Only batch element 0 / channel 0 is scattered back.
            next_logits[nidx] = grid_logits[0, ..., 0]
            grid_logits = next_logits.unsqueeze(0)

        # Voxels that were never queried become NaN.
        grid_logits[grid_logits == -10000.] = float('nan')

        return grid_logits


class FlashVDMVolumeDecoding:
    """Hierarchical decoding that drives a FlashVDM cross-attention processor."""

    def __init__(self, topk_mode='mean'):
        """Select the cross-attention processor.

        Args:
            topk_mode: ``'mean'`` selects ``FlashVDMCrossAttentionProcessor``,
                ``'merge'`` selects ``FlashVDMTopMCrossAttentionProcessor``.

        Raises:
            ValueError: If ``topk_mode`` is neither ``'mean'`` nor ``'merge'``.
        """
        if topk_mode not in ['mean', 'merge']:
            raise ValueError(f'Unsupported topk_mode {topk_mode}, available: {["mean", "merge"]}')

        if topk_mode == 'mean':
            self.processor = FlashVDMCrossAttentionProcessor()
        else:
            self.processor = FlashVDMTopMCrossAttentionProcessor()

    @torch.no_grad()
    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: CrossAttentionDecoder,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        mc_level: float = 0.0,
        octree_resolution: int = None,
        min_resolution: int = 63,
        mini_grid_num: int = 4,
        enable_pbar: bool = True,
        **kwargs,
    ):
        """Decode a coarse grid in mini-grid blocks, then refine near the surface.

        Installs ``self.processor`` on ``geo_decoder`` via
        ``set_cross_attention_processor`` and sets ``processor.topk`` before
        each decoder call.

        Args:
            latents: Latent tensor; the batch dimension is squeezed, so a
                batch of one is effectively assumed.
            geo_decoder: Decoder exposing ``set_cross_attention_processor``
                and callable as ``geo_decoder(queries=..., latents=...)``.
            bounds: Scalar half-extent of a cube centred at the origin, or a
                6-sequence ``[xmin, ymin, zmin, xmax, ymax, zmax]``.
            num_chunks: Target upper bound on query points per decoder call.
            mc_level: Offset passed to ``extract_near_surface_volume_fn``.
            octree_resolution: Requested finest resolution; must be provided.
                The resolutions actually used are adjusted (see below) and
                logged.
            min_resolution: Halving of the resolution stops below this value.
            mini_grid_num: Number of blocks per axis that the coarse grid is
                split into.
            enable_pbar: Show the tqdm progress bar of the coarse pass.
            **kwargs: Ignored.

        Returns:
            Tensor of shape ``(1, R + 1, R + 1, R + 1)`` for the finest used
            resolution ``R`` when refinement ran (voxels never queried are
            NaN), or the dense coarse grid when only one level exists.
        """
        processor = self.processor
        geo_decoder.set_cross_attention_processor(processor)

        device = latents.device
        dtype = latents.dtype

        resolutions = _resolution_pyramid(octree_resolution, min_resolution)
        # Make the coarse grid's point count (resolution + 1) a multiple of
        # mini_grid_num, then derive each finer level by exact doubling.
        resolutions[0] = round(resolutions[0] / mini_grid_num) * mini_grid_num - 1
        for level in range(1, len(resolutions)):
            resolutions[level] = resolutions[0] * 2 ** level

        logger.info(f"FlashVDMVolumeDecoding Resolution: {resolutions}")

        # 1. generate query points
        bbox_min, bbox_max = _bounds_to_bbox(bounds)
        bbox_size = bbox_max - bbox_min

        xyz_samples, grid_size, _ = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=resolutions[0],
            indexing="ij"
        )

        dilate = _make_dilation_conv(device, dtype)

        grid_size = np.array(grid_size)

        # 2. latents to 3d volume
        xyz_samples = torch.from_numpy(xyz_samples).to(device, dtype=dtype)
        batch_size = latents.shape[0]
        mini_grid_size = xyz_samples.shape[0] // mini_grid_num
        # Regroup the dense grid into mini_grid_num^3 blocks, each holding the
        # mini_grid_size^3 points of one spatially contiguous sub-cube.
        xyz_samples = xyz_samples.view(
            mini_grid_num, mini_grid_size,
            mini_grid_num, mini_grid_size,
            mini_grid_num, mini_grid_size, 3
        ).permute(
            0, 2, 4, 1, 3, 5, 6
        ).reshape(
            -1, mini_grid_size * mini_grid_size * mini_grid_size, 3
        )
        batch_logits = []
        # Number of blocks decoded per call (at least one).
        num_batchs = max(num_chunks // xyz_samples.shape[1], 1)
        for start in tqdm(range(0, xyz_samples.shape[0], num_batchs),
                          desc="FlashVDM Volume Decoding", disable=not enable_pbar):
            queries = xyz_samples[start: start + num_batchs, :]
            batch = queries.shape[0]
            batch_latents = repeat(latents.squeeze(0), "p c -> b p c", b=batch)
            processor.topk = True
            logits = geo_decoder(queries=queries, latents=batch_latents)
            batch_logits.append(logits)
        # Undo the block regrouping to recover the dense (x, y, z) layout.
        grid_logits = torch.cat(batch_logits, dim=0).reshape(
            mini_grid_num, mini_grid_num, mini_grid_num,
            mini_grid_size, mini_grid_size,
            mini_grid_size
        ).permute(0, 3, 1, 4, 2, 5).contiguous().view(
            (batch_size, grid_size[0], grid_size[1], grid_size[2])
        )

        # 3. sparse refinement, one level per remaining resolution
        for octree_depth_now in resolutions[1:]:
            grid_size = np.array([octree_depth_now + 1] * 3)
            voxel_size = bbox_size / octree_depth_now
            next_index = torch.zeros(tuple(grid_size), dtype=dtype, device=device)
            # -10000 marks voxels that are never queried at this level.
            next_logits = torch.full(next_index.shape, -10000., dtype=dtype, device=device)

            # Candidate voxels: sign changes against a neighbour, plus any
            # voxel whose absolute value is below 0.95.
            curr_points = extract_near_surface_volume_fn(grid_logits.squeeze(0), mc_level)
            curr_points += grid_logits.squeeze(0).abs() < 0.95

            expand_num = 0 if octree_depth_now == resolutions[-1] else 1
            for _ in range(expand_num):
                curr_points = dilate(curr_points.unsqueeze(0).to(dtype)).squeeze(0)

            # Coarse voxel (i, j, k) maps to fine voxel (2i, 2j, 2k).
            (cidx_x, cidx_y, cidx_z) = torch.where(curr_points > 0)
            next_index[cidx_x * 2, cidx_y * 2, cidx_z * 2] = 1
            for _ in range(2 - expand_num):
                next_index = dilate(next_index.unsqueeze(0)).squeeze(0)
            nidx = torch.where(next_index > 0)

            # Integer voxel indices -> float32 world coordinates.
            next_points = torch.stack(nidx, dim=1)
            next_points = (next_points * torch.tensor(voxel_size, dtype=torch.float32, device=device) +
                           torch.tensor(bbox_min, dtype=torch.float32, device=device))

            # Bucket the points into a query_grid_num^3 grid over their
            # bounding box and sort them by bucket id so that each decoder
            # call receives whole, spatially coherent buckets.
            query_grid_num = 6
            min_val = next_points.min(axis=0).values
            max_val = next_points.max(axis=0).values
            # The 0.001 margin keeps the maximum coordinate in the last bucket.
            vol_queries_index = (next_points - min_val) / (max_val - min_val) * (query_grid_num - 0.001)
            index = torch.floor(vol_queries_index).long()
            index = index[..., 0] * (query_grid_num ** 2) + index[..., 1] * query_grid_num + index[..., 2]
            index = index.sort()
            next_points = next_points[index.indices].unsqueeze(0).contiguous()
            unique_values = torch.unique(index.values, return_counts=True)
            grid_logits = torch.zeros((next_points.shape[1]), dtype=latents.dtype, device=latents.device)

            # Greedily pack consecutive buckets into calls of fewer than
            # num_chunks points (a single oversized bucket still goes alone).
            # processor.topk receives [bucket ids, bucket sizes] for the call
            # (NOTE(review): how the processor uses this is defined elsewhere).
            input_grid = [[], []]
            logits_grid_list = []
            start_num = 0
            sum_num = 0
            for grid_index, count in zip(unique_values[0].cpu().tolist(), unique_values[1].cpu().tolist()):
                if sum_num + count < num_chunks or sum_num == 0:
                    sum_num += count
                    input_grid[0].append(grid_index)
                    input_grid[1].append(count)
                else:
                    processor.topk = input_grid
                    logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num],
                                              latents=latents)
                    start_num = start_num + sum_num
                    logits_grid_list.append(logits_grid)
                    input_grid = [[grid_index], [count]]
                    sum_num = count
            # Flush the final, partially filled group.
            if sum_num > 0:
                processor.topk = input_grid
                logits_grid = geo_decoder(queries=next_points[:, start_num:start_num + sum_num],
                                          latents=latents)
                logits_grid_list.append(logits_grid)
            logits_grid = torch.cat(logits_grid_list, dim=1)
            # Undo the bucket sort so values line up with nidx again.
            grid_logits[index.indices] = logits_grid.squeeze(0).squeeze(-1)
            next_logits[nidx] = grid_logits
            grid_logits = next_logits.unsqueeze(0)

        # Voxels that were never queried become NaN.
        grid_logits[grid_logits == -10000.] = float('nan')

        return grid_logits