|
from typing import Dict |
|
|
|
import numpy as np |
|
import skimage |
|
import torch |
|
from tqdm.auto import tqdm |
|
|
|
from point_e.models.sdf import PointCloudSDFModel |
|
|
|
from .mesh import TriMesh |
|
from .point_cloud import PointCloud |
|
|
|
|
|
def marching_cubes_mesh( |
|
pc: PointCloud, |
|
model: PointCloudSDFModel, |
|
batch_size: int = 4096, |
|
grid_size: int = 128, |
|
side_length: float = 1.02, |
|
fill_vertex_channels: bool = True, |
|
progress: bool = False, |
|
) -> TriMesh: |
|
""" |
|
Run marching cubes on the SDF predicted from a point cloud to produce a |
|
mesh representing the 3D surface. |
|
|
|
:param pc: the point cloud to apply marching cubes to. |
|
:param model: the model to use to predict SDF values. |
|
:param grid_size: the number of samples along each axis. A total of |
|
grid_size**3 function evaluations are performed. |
|
:param side_length: the size of the cube containing the model, which is |
|
assumed to be centered at the origin. |
|
:param fill_vertex_channels: if True, use the nearest neighbor of each mesh |
|
vertex in the point cloud to compute vertex |
|
data (e.g. colors). |
|
""" |
|
voxel_size = side_length / (grid_size - 1) |
|
min_coord = -side_length / 2 |
|
|
|
def int_coord_to_float(int_coords: torch.Tensor) -> torch.Tensor: |
|
return int_coords.float() * voxel_size + min_coord |
|
|
|
with torch.no_grad(): |
|
cond = model.encode_point_clouds( |
|
torch.from_numpy(pc.coords).permute(1, 0).to(model.device)[None] |
|
) |
|
|
|
indices = range(0, grid_size**3, batch_size) |
|
if progress: |
|
indices = tqdm(indices) |
|
|
|
volume = [] |
|
for i in indices: |
|
indices = torch.arange( |
|
i, min(i + batch_size, grid_size**3), step=1, dtype=torch.int64, device=model.device |
|
) |
|
zs = int_coord_to_float(indices % grid_size) |
|
ys = int_coord_to_float(torch.div(indices, grid_size, rounding_mode="trunc") % grid_size) |
|
xs = int_coord_to_float(torch.div(indices, grid_size**2, rounding_mode="trunc")) |
|
coords = torch.stack([xs, ys, zs], dim=0) |
|
with torch.no_grad(): |
|
volume.append(model(coords[None], encoded=cond)[0]) |
|
volume_np = torch.cat(volume).view(grid_size, grid_size, grid_size).cpu().numpy() |
|
|
|
if np.all(volume_np < 0) or np.all(volume_np > 0): |
|
|
|
|
|
volume_np -= np.mean(volume_np) |
|
|
|
verts, faces, normals, _ = skimage.measure.marching_cubes( |
|
volume=volume_np, |
|
level=0, |
|
allow_degenerate=False, |
|
spacing=(voxel_size,) * 3, |
|
) |
|
|
|
|
|
|
|
|
|
|
|
old_f1 = faces[:, 0].copy() |
|
faces[:, 0] = faces[:, 1] |
|
faces[:, 1] = old_f1 |
|
|
|
verts += min_coord |
|
return TriMesh( |
|
verts=verts, |
|
faces=faces, |
|
normals=normals, |
|
vertex_channels=None if not fill_vertex_channels else _nearest_vertex_channels(pc, verts), |
|
) |
|
|
|
|
|
def _nearest_vertex_channels(pc: PointCloud, verts: np.ndarray) -> Dict[str, np.ndarray]: |
|
nearest = pc.nearest_points(verts) |
|
return {ch: arr[nearest] for ch, arr in pc.channels.items()} |
|
|