File size: 5,152 Bytes
616f571
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import argparse
import logging

import numpy as np
import torch
import trimesh

from cube3d.inference.utils import load_config, load_model_weights, parse_structured
from cube3d.model.autoencoder.one_d_autoencoder import OneDAutoEncoder

# Default half-extent of the normalized bounding box: the longest side of a
# rescaled mesh spans [-MESH_SCALE, MESH_SCALE].
MESH_SCALE = 0.96


def rescale(vertices: np.ndarray, mesh_scale: float = MESH_SCALE) -> np.ndarray:
    """Center vertices on the origin and scale them uniformly into a cube.

    The axis-aligned bounding box is moved so that its center sits at the
    origin, then a single scale factor is applied so that the longest side of
    the box has length ``2 * mesh_scale``. Every coordinate therefore ends up
    inside ``[-mesh_scale, mesh_scale]`` (``[-1, 1]`` when ``mesh_scale=1.0``),
    and the aspect ratio is preserved.

    Args:
        vertices: Array with one vertex per row (e.g. shape ``(N, 3)``);
            the bounding box is taken along axis 0.
        mesh_scale: Half of the side length of the target cube.

    Returns:
        A new array with the rescaled vertices; the input is not modified.

    Raises:
        ValueError: If ``vertices`` has no rows, because NumPy cannot take the
            ``min``/``max`` of an empty axis.
    """
    bbmin = vertices.min(0)
    bbmax = vertices.max(0)
    center = (bbmin + bbmax) * 0.5
    longest_side = (bbmax - bbmin).max()
    if longest_side == 0:
        # All points coincide, so there is no extent to normalize. Dividing by
        # the zero extent would turn every coordinate into NaN; centering is
        # the only meaningful operation left.
        return vertices - center
    scale = 2.0 * mesh_scale / longest_side
    return (vertices - center) * scale


def load_scaled_mesh(file_path: str) -> trimesh.Trimesh:
    """
    Load a mesh from a file, clean it, and normalize its size with `rescale`.
    The file is loaded with `force="mesh"`. Cleaning removes infinite values,
    keeps only the faces reported by `nondegenerate_faces()` and then by
    `unique_faces()`, and finally drops vertices no face refers to.
    Parameters:
        file_path: str
            Path of the mesh file to load.
    Returns:
        mesh: trimesh.Trimesh
            The cleaned mesh whose vertices went through `rescale`.
    Raises:
        ValueError: if no vertices or no faces are left after cleaning.
    """
    cleaned: trimesh.Trimesh = trimesh.load(file_path, force="mesh")
    cleaned.remove_infinite_values()
    # Each mask is computed only after the previous filter was applied, so the
    # selectors are called lazily inside the loop, in this exact order.
    for select_faces in (cleaned.nondegenerate_faces, cleaned.unique_faces):
        cleaned.update_faces(select_faces())
    cleaned.remove_unreferenced_vertices()

    has_geometry = len(cleaned.vertices) > 0 and len(cleaned.faces) > 0
    if not has_geometry:
        raise ValueError("Mesh has no vertices or faces after cleaning")

    cleaned.vertices = rescale(cleaned.vertices)
    return cleaned


def load_and_process_mesh(file_path: str, n_samples: int = 8192):
    """
    Turn a mesh file into a batched point cloud with per-point normals.
    The mesh is loaded through `load_scaled_mesh`, points are sampled from its
    surface, and every sampled point is paired with the normal of the face it
    was sampled from.
    Args:
        file_path (str): The file path to the 3D mesh file.
        n_samples (int, optional): How many surface points to request from the sampler. Defaults to 8192.
    Returns:
        torch.Tensor: A float tensor reshaped to (1, -1, 6), i.e. a single batch entry.
                        Each row holds a position (x, y, z) followed by a face normal (nx, ny, nz).
    """

    scaled_mesh = load_scaled_mesh(file_path)
    sampled_xyz, source_faces = trimesh.sample.sample_surface(scaled_mesh, n_samples)
    # A sampled point takes the normal of the face it lies on.
    sampled_normals = scaled_mesh.face_normals[source_faces]
    features = np.concatenate((sampled_xyz, sampled_normals), axis=1)
    # Add the leading batch dimension before handing the data to torch.
    batched = features.reshape(1, -1, 6)
    return torch.from_numpy(batched).float()


@torch.inference_mode()
def run_shape_decode(
    shape_model: OneDAutoEncoder,
    output_ids: torch.Tensor,
    resolution_base: float = 8.0,
    chunk_size: int = 100_000,
):
    """
    Decodes the shape from the given output IDs and extracts the geometry.
    Only the first `shape_model.cfg.num_encoder_latents` entries along dim 1 are
    used, and they are clamped to the valid code range
    [0, shape_model.cfg.num_codes - 1] before decoding. `output_ids` itself is
    left untouched.
    Args:
        shape_model (OneDAutoEncoder): The shape model.
        output_ids (torch.Tensor): The tensor containing the output IDs.
        resolution_base (float, optional): The base resolution for geometry extraction. Defaults to 8.0.
        chunk_size (int, optional): The chunk size for processing. Defaults to 100,000.
    Returns:
        The first element returned by `shape_model.extract_geometry`; callers in
        this file index it as a sequence of (vertices, faces) pairs.
    """
    num_latents = shape_model.cfg.num_encoder_latents
    shape_ids = (
        output_ids[:, :num_latents, ...]
        # Out-of-place clamp: the slice is a view, so an in-place clamp_ would
        # silently overwrite the caller's tensor.
        .clamp(0, shape_model.cfg.num_codes - 1)
        # reshape (unlike view) also accepts layouts that are not
        # view-compatible.
        .reshape(-1, num_latents)
    )
    latents = shape_model.decode_indices(shape_ids)
    mesh_v_f, _ = shape_model.extract_geometry(
        latents,
        resolution_base=resolution_base,
        chunk_size=chunk_size,
        use_warp=True,
    )
    return mesh_v_f


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of the encode/decode example."""
    parser = argparse.ArgumentParser(
        description="cube shape encode and decode example script"
    )
    parser.add_argument(
        "--mesh-path",
        type=str,
        required=True,
        help="Path to the input mesh file.",
    )
    parser.add_argument(
        "--config-path",
        type=str,
        default="cube3d/configs/open_model.yaml",
        help="Path to the configuration YAML file.",
    )
    parser.add_argument(
        "--shape-ckpt-path",
        type=str,
        required=True,
        help="Path to the shape encoder/decoder checkpoint file.",
    )
    parser.add_argument(
        "--recovered-mesh-path",
        type=str,
        default="recovered_mesh.obj",
        help="Path to save the recovered mesh file.",
    )
    return parser.parse_args()


def main() -> None:
    """Encode a mesh into shape indices, decode them again, and save the result.

    Steps: load the config and the shape autoencoder weights, sample a point
    cloud from the input mesh, encode it to indices, print them, decode the
    indices back to geometry, and export the recovered mesh.
    """
    args = _parse_args()

    # Without a configured handler the root logger only shows WARNING and
    # above, so the device message below would never be emitted.
    logging.basicConfig(level=logging.INFO)
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    logging.info("Using device: %s", device)

    cfg = load_config(args.config_path)

    shape_model = OneDAutoEncoder(
        parse_structured(OneDAutoEncoder.Config, cfg.shape_model)
    )
    load_model_weights(
        shape_model,
        args.shape_ckpt_path,
    )
    shape_model = shape_model.eval().to(device)

    point_cloud = load_and_process_mesh(args.mesh_path)
    # Encoding is pure inference; skip autograd bookkeeping for it.
    with torch.inference_mode():
        output = shape_model.encode(point_cloud.to(device))
    # The indices are read from the dict stored at position 3 of the result.
    indices = output[3]["indices"]
    print("Got the following shape indices:")
    print(indices)
    print("Indices shape: ", indices.shape)

    mesh_v_f = run_shape_decode(shape_model, indices)
    # A single mesh was encoded, so only the first (vertices, faces) entry is used.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    mesh.export(args.recovered_mesh_path)


if __name__ == "__main__":
    main()