Spaces:
Running
on
Zero
Running
on
Zero
from typing import Union, Tuple, List, Callable | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from einops import repeat | |
from tqdm import tqdm | |
from .attention_blocks import CrossAttentionDecoder | |
from ...utils import logger | |
def generate_dense_grid_points(
    bbox_min: np.ndarray,
    bbox_max: np.ndarray,
    octree_resolution: int,
    indexing: str = "ij",
):
    """Build a dense, regularly spaced 3D lattice of points inside a box.

    Each axis is split into ``octree_resolution`` cells, so it carries
    ``octree_resolution + 1`` evenly spaced float32 coordinates running from
    ``bbox_min`` to ``bbox_max`` (both ends included).

    Args:
        bbox_min: Lower corner of the box; components 0, 1, 2 are read as x, y, z.
        bbox_max: Upper corner of the box, same layout as ``bbox_min``.
        octree_resolution: Number of cells per axis (converted with ``int``).
        indexing: Forwarded to ``np.meshgrid`` (``"ij"`` or ``"xy"``).

    Returns:
        A tuple ``(xyz, grid_size, length)``:
        ``xyz`` is the stacked meshgrid with the 3 coordinates on the last axis,
        ``grid_size`` is a list with the number of points along each axis, and
        ``length`` is ``bbox_max - bbox_min``.
    """
    # Computed first so that incompatible bbox inputs fail before any grid work.
    length = bbox_max - bbox_min

    points_per_axis = int(octree_resolution) + 1
    axis_coords = [
        np.linspace(bbox_min[dim], bbox_max[dim], points_per_axis, dtype=np.float32)
        for dim in range(3)
    ]

    mesh = np.meshgrid(*axis_coords, indexing=indexing)
    xyz = np.stack(tuple(mesh), axis=-1)

    grid_size = [points_per_axis] * 3
    return xyz, grid_size, length
class VanillaVolumeDecoder:
    """Decode latents into a dense 3D grid of logits.

    The decoder samples a regular lattice of query points inside a bounding
    box, feeds them to ``geo_decoder`` in fixed-size chunks together with the
    latents, and reassembles the per-chunk outputs into one volume.
    """

    def __call__(
        self,
        latents: torch.FloatTensor,
        geo_decoder: Callable,
        bounds: Union[Tuple[float], List[float], float] = 1.01,
        num_chunks: int = 10000,
        octree_resolution: int = None,
        enable_pbar: bool = True,
        **kwargs,
    ):
        """Evaluate ``geo_decoder`` on a dense grid and return the logits volume.

        Args:
            latents: Latent tensor; its first dimension is used as the batch
                size, and its device/dtype are used for the query points.
            geo_decoder: Callable invoked as
                ``geo_decoder(queries=..., latents=...)`` for every chunk.
                Its outputs are concatenated along dim 1, so the per-chunk
                point axis is assumed to be dim 1 (TODO confirm against the
                decoder implementation).
            bounds: Either a single number ``b`` (box ``[-b, b]`` on every
                axis) or a 6-element sequence
                ``(xmin, ymin, zmin, xmax, ymax, zmax)``.
            num_chunks: Maximum number of query points sent to
                ``geo_decoder`` per call. Must be positive.
            octree_resolution: Number of grid cells per axis; the grid has
                ``octree_resolution + 1`` points per axis. Required despite
                the ``None`` default kept for interface compatibility.
            enable_pbar: Show a tqdm progress bar when True.
            **kwargs: Accepted and ignored.

        Returns:
            A float32 tensor viewed as ``(batch_size, *grid_size)``, where
            ``grid_size`` holds the number of grid points per axis.

        Raises:
            TypeError: If ``octree_resolution`` is None.
            ValueError: If ``num_chunks`` is not positive.
        """
        # Fail early with a readable message instead of an opaque int(None)
        # error raised from inside the grid helper.
        if octree_resolution is None:
            raise TypeError("octree_resolution must be an integer, got None")
        if num_chunks <= 0:
            raise ValueError(f"num_chunks must be positive, got {num_chunks}")

        device = latents.device
        dtype = latents.dtype
        batch_size = latents.shape[0]

        # 1. generate query points
        # Accept plain ints too (e.g. bounds=1); previously only floats were
        # expanded and an int crashed when sliced below.
        if isinstance(bounds, (int, float)):
            bounds = [-bounds, -bounds, -bounds, bounds, bounds, bounds]

        bbox_min, bbox_max = np.array(bounds[0:3]), np.array(bounds[3:6])
        xyz_samples, grid_size, _ = generate_dense_grid_points(
            bbox_min=bbox_min,
            bbox_max=bbox_max,
            octree_resolution=octree_resolution,
            indexing="ij",
        )
        # Flatten the lattice to a (num_points, 3) list of queries.
        xyz_samples = (
            torch.from_numpy(xyz_samples)
            .to(device, dtype=dtype)
            .contiguous()
            .reshape(-1, 3)
        )

        # 2. latents to 3d volume
        batch_logits = []
        for start in tqdm(
            range(0, xyz_samples.shape[0], num_chunks),
            desc="Volume Decoding",
            disable=not enable_pbar,
        ):
            chunk_queries = xyz_samples[start: start + num_chunks, :]
            # Every batch element is queried at the same points.
            chunk_queries = repeat(chunk_queries, "p c -> b p c", b=batch_size)
            logits = geo_decoder(queries=chunk_queries, latents=latents)
            batch_logits.append(logits)

        grid_logits = torch.cat(batch_logits, dim=1)
        grid_logits = grid_logits.view((batch_size, *grid_size)).float()
        return grid_logits