File size: 5,208 Bytes
daa6779
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

import argparse
import glob
import importlib
import os
from datetime import datetime

import fpsample
import kiui
import meshiki
import numpy as np
import torch
import trimesh

from vae.model import Model
from vae.utils import box_normalize, postprocess_mesh, sphere_normalize, sync_timer

# Command-line interface. Example invocation:
#   PYTHONPATH=. python vae/scripts/infer.py
parser = argparse.ArgumentParser()
# Dotted module path; the module is imported later and must expose make_config().
parser.add_argument("--config", default="vae.configs.part_woenc", type=str, help="config file path")
parser.add_argument("--ckpt_path", default="pretrained/vae.pt", type=str, help="checkpoint path")
# Every entry directly inside --input is treated as one mesh to process.
parser.add_argument("--input", default="assets/meshes/", type=str, help="input directory")
parser.add_argument("--output_dir", default="output/", type=str, help="output directory")
# Values <= 0 (including the default -1) mean "process everything".
parser.add_argument("--limit", default=-1, type=int, help="how many samples to test")
parser.add_argument("--num_fps_point", default=1024, type=int, help="number of fps points")
parser.add_argument("--num_fps_salient_point", default=1024, type=int, help="number of fps salient points")
parser.add_argument("--grid_res", default=512, type=int, help="grid resolution")
parser.add_argument("--seed", default=42, type=int, help="seed")
args = parser.parse_args()


# 3x3 float32 permutation matrix.
# NOTE(review): judging by the name this is an axis remap for trimesh GLB
# export; it is not referenced by the code visible in this script - confirm.
TRIMESH_GLB_EXPORT = np.array(
    [
        [0, 1, 0],
        [0, 0, 1],
        [1, 0, 0],
    ],
    dtype=np.float32,
)

# Seed via kiui before any sampling happens so runs with the same --seed match.
kiui.seed_everything(args.seed)


@sync_timer("prepare_input_from_mesh")
def prepare_input_from_mesh(mesh_path, use_salient_point=True, num_fps_point=1024, num_fps_salient_point=1024):
    """Build the point-cloud inputs for the VAE from a mesh file.

    The mesh is loaded with ``meshiki.load_mesh``, normalized with
    ``box_normalize`` and then sampled: uniformly over the surface and,
    optionally, at salient locations. For each point set a farthest-point
    subsample is computed and returned as indices into that set.

    Args:
        mesh_path: Path of the mesh file to load. The mesh is assumed to be
            already processed to be watertight.
        use_salient_point: When true, also sample salient surface points and
            add the ``*_dorases`` entries to the result. When false the
            salient sampling is skipped entirely.
        num_fps_point: Number of FPS indices to pick from the uniform points.
        num_fps_salient_point: Number of FPS indices to pick from the salient
            points.

    Returns:
        A dict of torch tensors (created with ``torch.from_numpy``, so they
        live on the CPU and carry no batch dimension):

        * ``"pointcloud"``: uniform surface points.
        * ``"fps_indices"``: int64 indices into ``"pointcloud"``,
          ``[num_fps_point,]``.
        * ``"pointcloud_dorases"``: salient surface points, ``[N', 3]``
          (only when ``use_salient_point``).
        * ``"fps_indices_dorases"``: int64 indices into
          ``"pointcloud_dorases"``, ``[num_fps_salient_point,]``
          (only when ``use_salient_point``).
    """
    vertices, faces = meshiki.load_mesh(mesh_path)

    # Box normalization is the one in use; sphere_normalize(vertices) is the
    # alternative that was tried here.
    vertices = box_normalize(vertices)

    mesh = meshiki.Mesh(vertices, faces)

    # Oversample the surface, then thin it out with FPS.
    uniform_surface_points = mesh.uniform_point_sample(200000)
    uniform_surface_points = meshiki.fps(uniform_surface_points, 32768)  # hardcoded...

    # Only pay for salient sampling when the caller actually wants it.
    salient_surface_points = None
    if use_salient_point:
        salient_surface_points = mesh.salient_point_sample(16384, thresh_bihedral=15)

    # Debug aid: both point sets can be inspected by exporting them with
    # trimesh.PointCloud(vertices=...).export("<name>.ply").

    sample = {}

    sample["pointcloud"] = torch.from_numpy(uniform_surface_points)

    # fps subsample
    fps_indices = fpsample.bucket_fps_kdline_sampling(uniform_surface_points, num_fps_point, h=5, start_idx=0)
    sample["fps_indices"] = torch.from_numpy(fps_indices).long()  # [num_fps_point,]

    if salient_surface_points is not None:
        sample["pointcloud_dorases"] = torch.from_numpy(salient_surface_points)  # [N', 3]

        # fps subsample
        fps_indices_dorases = fpsample.bucket_fps_kdline_sampling(
            salient_surface_points, num_fps_salient_point, h=5, start_idx=0
        )
        sample["fps_indices_dorases"] = torch.from_numpy(fps_indices_dorases).long()  # [num_fps_salient_point,]

    return sample


# ---------------------------------------------------------------------------
# Inference driver: load the checkpoint, build the model, then reconstruct
# every mesh found in --input and export the result as a .glb file.
# ---------------------------------------------------------------------------

print(f"Loading checkpoint from {args.ckpt_path}")
# Load onto the CPU first so a checkpoint saved from a device that does not
# exist on this machine still loads; load_state_dict below copies the weights
# into the CUDA model.
ckpt_dict = torch.load(args.ckpt_path, map_location="cpu", weights_only=True)

# If the checkpoint wraps the weights under a "model" key, keep only those.
if "model" in ckpt_dict:
    ckpt_dict = ckpt_dict["model"]

# instantiate model
print(f"Instantiating model from {args.config}")
# --config is a dotted module path whose module exposes make_config().
model_config = importlib.import_module(args.config).make_config()
model = Model(model_config).eval().cuda().bfloat16()

# load weight
print(f"Loading weights from {args.ckpt_path}")
model.load_state_dict(ckpt_dict, strict=True)

# One output folder per run: <output_dir>/vae_<config name>_<timestamp>
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
workspace = os.path.join(args.output_dir, "vae_" + args.config.split(".")[-1] + "_" + timestamp)
if not os.path.exists(workspace):
    os.makedirs(workspace)
else:
    # Same effect as `rm <workspace>/*` (non-hidden, non-directory entries),
    # but without handing a user-supplied path to a shell.
    for stale_path in glob.glob(os.path.join(glob.escape(workspace), "*")):
        if os.path.islink(stale_path) or not os.path.isdir(stale_path):
            os.remove(stale_path)
print(f"Output directory: {workspace}")

# load dataset
# glob returns entries in arbitrary order; sort so that --limit always picks
# the same subset, and skip sub-directories, which cannot be loaded as meshes.
mesh_list = sorted(path for path in glob.glob(os.path.join(args.input, "*")) if os.path.isfile(path))
mesh_list = mesh_list[: args.limit] if args.limit > 0 else mesh_list

for i, mesh_path in enumerate(mesh_list):
    print(f"Processing {i}/{len(mesh_list)}: {mesh_path}")

    # Strip the directory and only the final extension, so "a.b.obj" becomes
    # "a.b" and does not collide with "a.c.obj" in the output folder.
    mesh_name = os.path.splitext(os.path.basename(mesh_path))[0]

    sample = prepare_input_from_mesh(
        mesh_path, num_fps_point=args.num_fps_point, num_fps_salient_point=args.num_fps_salient_point
    )
    # Add a batch dimension of 1 and move every tensor to the GPU.
    for k in sample:
        sample[k] = sample[k].unsqueeze(0).cuda()

    # call vae
    with torch.inference_mode():
        output = model(sample, resolution=args.grid_res)

    # Batch size is 1, so the first entry is the only reconstructed mesh.
    vertices, faces = output["meshes"][0]

    mesh = trimesh.Trimesh(vertices, faces)
    # NOTE(review): 5e5 is presumably a face budget for the post-processing
    # step - confirm against vae.utils.postprocess_mesh.
    mesh = postprocess_mesh(mesh, 5e5)

    mesh.export(os.path.join(workspace, mesh_name + ".glb"))