Spaces:
Running
on
L40S
Running
on
L40S
File size: 5,152 Bytes
616f571 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 |
import argparse
import logging
import numpy as np
import torch
import trimesh
from cube3d.inference.utils import load_config, load_model_weights, parse_structured
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder
MESH_SCALE = 0.96


def rescale(vertices: np.ndarray, mesh_scale: float = MESH_SCALE) -> np.ndarray:
    """Center vertices on the origin and scale them uniformly into a cube.

    The axis-aligned bounding box of ``vertices`` is moved so that its center
    sits at the origin, then scaled by a single factor so that its longest
    side has length ``2 * mesh_scale``. With ``mesh_scale=1.0`` the result
    fits inside [-1, -1, -1] to [1, 1, 1].

    Args:
        vertices: Array-like of points, one point per row (the bounding box
            is computed over axis 0).
        mesh_scale: Half the length of the longest bounding-box side after
            scaling.

    Returns:
        A new array of the same shape as ``vertices``; the input is not
        modified.

    Raises:
        ValueError: If ``vertices`` contains no elements.
    """
    vertices = np.asarray(vertices)
    if vertices.size == 0:
        raise ValueError("Cannot rescale an empty vertex array")

    bbmin = vertices.min(0)
    bbmax = vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    extent = (bbmax - bbmin).max()

    if extent == 0:
        # All points coincide: there is no size to normalize, and dividing by
        # the zero extent would turn every coordinate into NaN. Center only.
        return vertices - center

    return (vertices - center) * (2.0 * mesh_scale / extent)
def load_scaled_mesh(file_path: str) -> trimesh.Trimesh:
    """Load a mesh from disk, clean it, and normalize its vertices.

    Cleaning removes infinite values, degenerate faces, duplicate faces and
    vertices that no face references, in that order. The surviving vertices
    are then passed through ``rescale`` with its default scale.

    Args:
        file_path: Path of the mesh file, loaded with ``force="mesh"``.

    Returns:
        The cleaned mesh with rescaled vertices.

    Raises:
        ValueError: If cleaning leaves the mesh without vertices or faces.
    """
    loaded: trimesh.Trimesh = trimesh.load(file_path, force="mesh")

    # Cleanup passes; each face mask is computed on the mesh as left by the
    # previous pass, so the order matters.
    loaded.remove_infinite_values()
    keep = loaded.nondegenerate_faces()
    loaded.update_faces(keep)
    keep = loaded.unique_faces()
    loaded.update_faces(keep)
    loaded.remove_unreferenced_vertices()

    has_geometry = len(loaded.vertices) != 0 and len(loaded.faces) != 0
    if not has_geometry:
        raise ValueError("Mesh has no vertices or faces after cleaning")

    normalized = rescale(loaded.vertices)
    loaded.vertices = normalized
    return loaded
def load_and_process_mesh(file_path: str, n_samples: int = 8192):
    """Turn a mesh file into a batched point cloud with per-point normals.

    The mesh is loaded through ``load_scaled_mesh``, points are sampled on
    its surface, and every sampled point is paired with the normal of the
    face it was drawn from.

    Args:
        file_path (str): The file path to the 3D mesh file.
        n_samples (int, optional): How many surface points to sample.
            Defaults to 8192.

    Returns:
        torch.Tensor: A float tensor reshaped to (1, -1, 6); each row holds a
        position (x, y, z) followed by its face normal (nx, ny, nz).
    """
    scaled_mesh = load_scaled_mesh(file_path)

    sampled_xyz, hit_faces = trimesh.sample.sample_surface(scaled_mesh, n_samples)
    sampled_normals = scaled_mesh.face_normals[hit_faces]

    # Positions and normals side by side: one row per sample, six columns.
    features = np.concatenate((sampled_xyz, sampled_normals), axis=1)

    # Add the leading batch dimension before handing the data to torch.
    batched = features.reshape(1, -1, 6)
    return torch.from_numpy(batched).float()
@torch.inference_mode()
def run_shape_decode(
    shape_model: OneDAutoEncoder,
    output_ids: torch.Tensor,
    resolution_base: float = 8.0,
    chunk_size: int = 100_000,
):
    """
    Decodes the shape from the given output IDs and extracts the geometry.

    Only the first ``shape_model.cfg.num_encoder_latents`` entries along
    dimension 1 are used, and IDs are clamped into
    ``[0, shape_model.cfg.num_codes - 1]`` before decoding. The clamp is done
    on a copy, so the caller's ``output_ids`` tensor is left unchanged.

    Args:
        shape_model (OneDAutoEncoder): The shape model.
        output_ids (torch.Tensor): The tensor containing the output IDs.
        resolution_base (float, optional): The base resolution for geometry
            extraction. Defaults to 8.0.
        chunk_size (int, optional): The chunk size for processing.
            Defaults to 100,000.
    Returns:
        The first element of the result of ``shape_model.extract_geometry``.
        NOTE(review): the script below indexes it as ``[0][0]`` (vertices)
        and ``[0][1]`` (faces), i.e. presumably one (vertices, faces) pair
        per batch item - confirm against ``extract_geometry``.
    """
    num_latents = shape_model.cfg.num_encoder_latents
    shape_ids = (
        output_ids[:, :num_latents, ...]
        # Out-of-place clamp: the slice is a view of ``output_ids``, so an
        # in-place ``clamp_`` would overwrite the caller's tensor.
        .clamp(0, shape_model.cfg.num_codes - 1)
        .reshape(-1, num_latents)
    )
    latents = shape_model.decode_indices(shape_ids)
    mesh_v_f, _ = shape_model.extract_geometry(
        latents,
        resolution_base=resolution_base,
        chunk_size=chunk_size,
        use_warp=True,
    )
    return mesh_v_f
def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the encode/decode example script."""
    parser = argparse.ArgumentParser(
        description="cube shape encode and decode example script"
    )
    parser.add_argument(
        "--mesh-path",
        type=str,
        required=True,
        help="Path to the input mesh file.",
    )
    parser.add_argument(
        "--config-path",
        type=str,
        default="cube3d/configs/open_model.yaml",
        help="Path to the configuration YAML file.",
    )
    parser.add_argument(
        "--shape-ckpt-path",
        type=str,
        required=True,
        help="Path to the shape encoder/decoder checkpoint file.",
    )
    parser.add_argument(
        "--recovered-mesh-path",
        type=str,
        default="recovered_mesh.obj",
        help="Path to save the recovered mesh file.",
    )
    return parser.parse_args()


def main() -> None:
    """Encode a mesh into shape indices, decode them, and export the result.

    Steps: parse the CLI options, build the shape model from the config and
    checkpoint, encode a point cloud sampled from the input mesh, print the
    resulting indices, decode them back to geometry and write the recovered
    mesh to ``--recovered-mesh-path``.
    """
    # Without a configured handler the root logger drops INFO records, so the
    # device message below would never be shown.
    logging.basicConfig(level=logging.INFO)

    args = _parse_args()

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    logging.info("Using device: %s", device)

    cfg = load_config(args.config_path)
    shape_model = OneDAutoEncoder(
        parse_structured(OneDAutoEncoder.Config, cfg.shape_model)
    )
    load_model_weights(
        shape_model,
        args.shape_ckpt_path,
    )
    shape_model = shape_model.eval().to(device)

    point_cloud = load_and_process_mesh(args.mesh_path)

    # Encoding is inference only; skip autograd bookkeeping, matching the
    # decode step which already runs under inference mode.
    with torch.inference_mode():
        output = shape_model.encode(point_cloud.to(device))
    # The indices are read from the "indices" entry of the 4th encode output.
    indices = output[3]["indices"]
    print("Got the following shape indices:")
    print(indices)
    print("Indices shape: ", indices.shape)

    mesh_v_f = run_shape_decode(shape_model, indices)
    # First batch item: element 0 holds the vertices, element 1 the faces.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    mesh.export(args.recovered_mesh_path)


if __name__ == "__main__":
    main()
|