"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.

NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""

import argparse
import glob
import importlib
import os
from datetime import datetime

import fpsample
import kiui
import meshiki
import numpy as np
import torch
import trimesh

from vae.model import Model
from vae.utils import box_normalize, postprocess_mesh, sphere_normalize, sync_timer

# Usage: PYTHONPATH=. python vae/scripts/infer.py

parser = argparse.ArgumentParser()
parser.add_argument("--config", type=str, help="config file path", default="vae.configs.part_woenc")
parser.add_argument(
    "--ckpt_path",
    type=str,
    help="checkpoint path",
    default="pretrained/vae.pt",
)
parser.add_argument("--input", type=str, help="input directory", default="assets/meshes/")
parser.add_argument("--output_dir", type=str, help="output directory", default="output/")
parser.add_argument("--limit", type=int, help="how many samples to test", default=-1)
parser.add_argument("--num_fps_point", type=int, help="number of fps points", default=1024)
parser.add_argument("--num_fps_salient_point", type=int, help="number of fps salient points", default=1024)
parser.add_argument("--grid_res", type=int, help="grid resolution", default=512)
parser.add_argument("--seed", type=int, help="seed", default=42)
args = parser.parse_args()

# NOTE(review): not referenced anywhere in this script; kept in case other
# modules import it. Presumably an axis permutation for GLB export - confirm.
TRIMESH_GLB_EXPORT = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]]).astype(np.float32)

kiui.seed_everything(args.seed)


@sync_timer("prepare_input_from_mesh")
def prepare_input_from_mesh(mesh_path, use_salient_point=True, num_fps_point=1024, num_fps_salient_point=1024):
    """Sample surface point clouds from a mesh file and pack them as model inputs.

    The mesh is assumed to be already processed to be watertight.

    Args:
        mesh_path: Path of the mesh file, read with ``meshiki.load_mesh``.
        use_salient_point: When true, the salient point cloud and its FPS
            indices are added to the returned sample as well.
        num_fps_point: Number of FPS indices selected from the uniform
            surface points.
        num_fps_salient_point: Number of FPS indices selected from the
            salient surface points.

    Returns:
        A dict of CPU tensors with the keys ``"pointcloud"`` and
        ``"fps_indices"``, plus ``"pointcloud_dorases"`` and
        ``"fps_indices_dorases"`` when ``use_salient_point`` is true.
    """
    vertices, faces = meshiki.load_mesh(mesh_path)

    # Box normalization is the one in use; the sphere_normalize alternative
    # (still imported above) is left disabled.
    vertices = box_normalize(vertices)

    mesh = meshiki.Mesh(vertices, faces)

    # Sampling budgets are hardcoded: 200000 uniform samples reduced to 32768
    # through meshiki.fps, and 16384 salient samples.
    uniform_surface_points = mesh.uniform_point_sample(200000)
    uniform_surface_points = meshiki.fps(uniform_surface_points, 32768)
    salient_surface_points = mesh.salient_point_sample(16384, thresh_bihedral=15)

    sample = {}

    sample["pointcloud"] = torch.from_numpy(uniform_surface_points)

    # fps subsample
    fps_indices = fpsample.bucket_fps_kdline_sampling(uniform_surface_points, num_fps_point, h=5, start_idx=0)
    sample["fps_indices"] = torch.from_numpy(fps_indices).long()  # [num_fps_point,]

    if use_salient_point:
        sample["pointcloud_dorases"] = torch.from_numpy(salient_surface_points)  # [N', 3]

        # fps subsample
        fps_indices_dorases = fpsample.bucket_fps_kdline_sampling(
            salient_surface_points, num_fps_salient_point, h=5, start_idx=0
        )
        sample["fps_indices_dorases"] = torch.from_numpy(fps_indices_dorases).long()  # [num_fps_salient_point,]

    return sample


print(f"Loading checkpoint from {args.ckpt_path}")
# Load onto the CPU so the checkpoint does not depend on the device it was
# saved from; load_state_dict below copies the weights into the CUDA model.
ckpt_dict = torch.load(args.ckpt_path, map_location="cpu", weights_only=True)

# delete all keys other than model
if "model" in ckpt_dict:
    ckpt_dict = ckpt_dict["model"]

# instantiate model
print(f"Instantiating model from {args.config}")
model_config = importlib.import_module(args.config).make_config()
model = Model(model_config).eval().cuda().bfloat16()

# load weight
print(f"Loading weights from {args.ckpt_path}")
model.load_state_dict(ckpt_dict, strict=True)

# One workspace per run: <output_dir>/vae_<last config module component>_<timestamp>
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
workspace = os.path.join(args.output_dir, "vae_" + args.config.split(".")[-1] + "_" + timestamp)
os.makedirs(workspace, exist_ok=True)
# Clear files left over from a previous run without going through a shell.
# Like the former `rm <workspace>/*`, this skips dot-files and leaves
# sub-directories untouched.
for stale_path in glob.glob(os.path.join(workspace, "*")):
    if os.path.isfile(stale_path) or os.path.islink(stale_path):
        os.remove(stale_path)
print(f"Output directory: {workspace}")

# load dataset
# glob returns paths in arbitrary order; sort so that --limit always selects
# the same subset.
mesh_list = sorted(glob.glob(os.path.join(args.input, "*")))
if args.limit > 0:
    mesh_list = mesh_list[: args.limit]

for i, mesh_path in enumerate(mesh_list):
    print(f"Processing {i}/{len(mesh_list)}: {mesh_path}")

    # Strip the directory and only the final extension, so "a.v1.obj" and
    # "a.v2.obj" do not both export to "a.glb".
    mesh_name = os.path.splitext(os.path.basename(mesh_path))[0]

    sample = prepare_input_from_mesh(
        mesh_path, num_fps_point=args.num_fps_point, num_fps_salient_point=args.num_fps_salient_point
    )
    # Add a leading batch dimension and move every tensor to the GPU.
    for k in sample:
        sample[k] = sample[k].unsqueeze(0).cuda()

    # call vae
    with torch.inference_mode():
        output = model(sample, resolution=args.grid_res)

    # Only the first mesh of the batch is exported; the latent is also
    # available as output["latent"] but is not used here.
    vertices, faces = output["meshes"][0]

    mesh = trimesh.Trimesh(vertices, faces)
    # NOTE(review): 5e5 is presumably a face-count budget - confirm against
    # vae.utils.postprocess_mesh.
    mesh = postprocess_mesh(mesh, 5e5)

    mesh.export(os.path.join(workspace, f"{mesh_name}.glb"))