# cube3d/vq_vae_encode_decode.py
# Origin: Akash Garg, "adding cube sources" (commit 616f571)
import argparse
import logging
import numpy as np
import torch
import trimesh
from cube3d.inference.utils import load_config, load_model_weights, parse_structured
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
MESH_SCALE = 0.96
def rescale(vertices: np.ndarray, mesh_scale: float = MESH_SCALE) -> np.ndarray:
"""Rescale the vertices to a cube, e.g., [-1, -1, -1] to [1, 1, 1] when mesh_scale=1.0"""
vertices = vertices
bbmin = vertices.min(0)
bbmax = vertices.max(0)
center = (bbmin + bbmax) * 0.5
scale = 2.0 * mesh_scale / (bbmax - bbmin).max()
vertices = (vertices - center) * scale
return vertices
def load_scaled_mesh(file_path: str) -> trimesh.Trimesh:
    """
    Load a mesh from disk, clean it, and rescale it into a centered cube.

    Cleaning removes infinite values, degenerate and duplicate faces, and
    unreferenced vertices; the surviving vertices are then rescaled via
    `rescale` (largest extent mapped to [-MESH_SCALE, MESH_SCALE]).

    Parameters:
        file_path: str
            Path to the mesh file (any format trimesh can load; forced to
            a single Trimesh via `force="mesh"`).
    Returns:
        mesh: trimesh.Trimesh
            The cleaned, rescaled mesh.
    Raises:
        ValueError: if the mesh has no vertices or faces left after cleaning.
    """
    mesh: trimesh.Trimesh = trimesh.load(file_path, force="mesh")
    mesh.remove_infinite_values()
    # Drop zero-area faces first, then duplicate faces, then any vertices
    # no longer referenced by the remaining faces.
    mesh.update_faces(mesh.nondegenerate_faces())
    mesh.update_faces(mesh.unique_faces())
    mesh.remove_unreferenced_vertices()
    if len(mesh.vertices) == 0 or len(mesh.faces) == 0:
        raise ValueError("Mesh has no vertices or faces after cleaning")
    mesh.vertices = rescale(mesh.vertices)
    return mesh
def load_and_process_mesh(file_path: str, n_samples: int = 8192):
    """
    Load a 3D mesh and turn it into a point cloud with per-point normals.

    The mesh is loaded and normalized via `load_scaled_mesh`, then
    `n_samples` points are sampled uniformly from its surface. Each sampled
    point is paired with the normal of the face it came from.

    Args:
        file_path (str): Path to the 3D mesh file.
        n_samples (int, optional): Number of surface samples. Defaults to 8192.

    Returns:
        torch.Tensor: Float tensor of shape (1, n_samples, 6); each row is
        (x, y, z, nx, ny, nz).
    """
    scaled_mesh = load_scaled_mesh(file_path)
    sampled_xyz, sampled_faces = trimesh.sample.sample_surface(scaled_mesh, n_samples)
    # Attach the containing face's normal to each sampled position.
    sampled_normals = scaled_mesh.face_normals[sampled_faces]
    points_with_normals = np.concatenate([sampled_xyz, sampled_normals], axis=1)
    # Add a leading batch dimension and convert to float32 for the encoder.
    return torch.from_numpy(points_with_normals.reshape(1, -1, 6)).float()
@torch.inference_mode()
def run_shape_decode(
    shape_model: OneDAutoEncoder,
    output_ids: torch.Tensor,
    resolution_base: float = 8.0,
    chunk_size: int = 100_000,
):
    """
    Decode shape token IDs back into mesh geometry.

    Args:
        shape_model (OneDAutoEncoder): The shape autoencoder.
        output_ids (torch.Tensor): Tensor of token IDs; the first
            `num_encoder_latents` entries along dim 1 are the shape tokens.
        resolution_base (float, optional): Base resolution for geometry
            extraction. Defaults to 8.0.
        chunk_size (int, optional): Chunk size for extraction. Defaults
            to 100,000.

    Returns:
        tuple: Vertices and faces of the reconstructed mesh.
    """
    num_latents = shape_model.cfg.num_encoder_latents
    shape_ids = output_ids[:, :num_latents, ...]
    # NOTE: clamp_ is in-place on a view, so out-of-range IDs are also
    # clamped in the caller's `output_ids` tensor (matches original behavior).
    shape_ids.clamp_(0, shape_model.cfg.num_codes - 1)
    latents = shape_model.decode_indices(shape_ids.view(-1, num_latents))
    mesh_v_f, _ = shape_model.extract_geometry(
        latents,
        resolution_base=resolution_base,
        chunk_size=chunk_size,
        use_warp=True,
    )
    return mesh_v_f
if __name__ == "__main__":
parser = argparse.ArgumentParser(
description="cube shape encode and decode example script"
)
parser.add_argument(
"--mesh-path",
type=str,
required=True,
help="Path to the input mesh file.",
)
parser.add_argument(
"--config-path",
type=str,
default="cube3d/configs/open_model.yaml",
help="Path to the configuration YAML file.",
)
parser.add_argument(
"--shape-ckpt-path",
type=str,
required=True,
help="Path to the shape encoder/decoder checkpoint file.",
)
parser.add_argument(
"--recovered-mesh-path",
type=str,
default="recovered_mesh.obj",
help="Path to save the recovered mesh file.",
)
args = parser.parse_args()
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
logging.info(f"Using device: {device}")
cfg = load_config(args.config_path)
shape_model = OneDAutoEncoder(
parse_structured(OneDAutoEncoder.Config, cfg.shape_model)
)
load_model_weights(
shape_model,
args.shape_ckpt_path,
)
shape_model = shape_model.eval().to(device)
point_cloud = load_and_process_mesh(args.mesh_path)
output = shape_model.encode(point_cloud.to(device))
indices = output[3]["indices"]
print("Got the following shape indices:")
print(indices)
print("Indices shape: ", indices.shape)
mesh_v_f = run_shape_decode(shape_model, indices)
vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
mesh.export(args.recovered_mesh_path)