Spaces:
Running
on
Zero
Running
on
Zero
File size: 5,208 Bytes
daa6779 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 |
"""
-----------------------------------------------------------------------------
Copyright (c) 2024, NVIDIA CORPORATION. All rights reserved.
NVIDIA CORPORATION and its licensors retain all intellectual property
and proprietary rights in and to this software, related documentation
and any modifications thereto. Any use, reproduction, disclosure or
distribution of this software and related documentation without an express
license agreement from NVIDIA CORPORATION is strictly prohibited.
-----------------------------------------------------------------------------
"""
import argparse
import glob
import importlib
import os
from datetime import datetime
import fpsample
import kiui
import meshiki
import numpy as np
import torch
import trimesh
from vae.model import Model
from vae.utils import box_normalize, postprocess_mesh, sphere_normalize, sync_timer
# Example invocation:
#   PYTHONPATH=. python vae/scripts/infer.py
parser = argparse.ArgumentParser()

# Command-line options as (flag, type, help text, default) rows; they are
# registered on the parser in the order listed here.
_CLI_OPTIONS = (
    ("--config", str, "config file path", "vae.configs.part_woenc"),
    ("--ckpt_path", str, "checkpoint path", "pretrained/vae.pt"),
    ("--input", str, "input directory", "assets/meshes/"),
    ("--output_dir", str, "output directory", "output/"),
    ("--limit", int, "how many samples to test", -1),
    ("--num_fps_point", int, "number of fps points", 1024),
    ("--num_fps_salient_point", int, "number of fps salient points", 1024),
    ("--grid_res", int, "grid resolution", 512),
    ("--seed", int, "seed", 42),
)
for _flag, _type, _help, _default in _CLI_OPTIONS:
    parser.add_argument(_flag, type=_type, help=_help, default=_default)

args = parser.parse_args()

# 3x3 float32 axis-permutation matrix.
# NOTE(review): not referenced anywhere in the visible part of this script;
# presumably an axis remap for trimesh GLB export - confirm before relying on it.
TRIMESH_GLB_EXPORT = np.array([[0, 1, 0], [0, 0, 1], [1, 0, 0]], dtype=np.float32)

# Seed the run from --seed.
kiui.seed_everything(args.seed)
@sync_timer("prepare_input_from_mesh")
def prepare_input_from_mesh(mesh_path, use_salient_point=True, num_fps_point=1024, num_fps_salient_point=1024):
    """Load a mesh and build the point-cloud sample dict passed to the model.

    The mesh is assumed to be already processed to be watertight. Its vertices
    are normalized with ``box_normalize`` before any points are sampled.

    Args:
        mesh_path: Path of the mesh file, forwarded to ``meshiki.load_mesh``.
        use_salient_point: When true, salient surface points are sampled and
            the two ``*_dorases`` entries are added to the result.
        num_fps_point: Number of FPS indices selected from the uniform points.
        num_fps_salient_point: Number of FPS indices selected from the salient
            points.

    Returns:
        dict of torch tensors with keys ``"pointcloud"`` and ``"fps_indices"``,
        plus ``"pointcloud_dorases"`` and ``"fps_indices_dorases"`` when
        ``use_salient_point`` is true.
    """
    vertices, faces = meshiki.load_mesh(mesh_path)
    vertices = box_normalize(vertices)
    mesh = meshiki.Mesh(vertices, faces)

    # Sample counts below are hardcoded: 200000 uniform samples are reduced to
    # 32768 points by meshiki.fps.
    uniform_surface_points = mesh.uniform_point_sample(200000)
    uniform_surface_points = meshiki.fps(uniform_surface_points, 32768)

    # Salient points are only needed for the optional "dorases" entries, so the
    # sampling is skipped when the caller does not ask for them. The call stays
    # at this position so the order of sampling calls is unchanged.
    # NOTE(review): thresh_bihedral=15 is presumably an angle threshold - confirm.
    salient_surface_points = (
        mesh.salient_point_sample(16384, thresh_bihedral=15) if use_salient_point else None
    )

    sample = {}
    sample["pointcloud"] = torch.from_numpy(uniform_surface_points)

    # fps subsample
    fps_indices = fpsample.bucket_fps_kdline_sampling(uniform_surface_points, num_fps_point, h=5, start_idx=0)
    sample["fps_indices"] = torch.from_numpy(fps_indices).long()  # [num_fps_point,]

    if use_salient_point:
        sample["pointcloud_dorases"] = torch.from_numpy(salient_surface_points)  # [N', 3]

        # fps subsample
        fps_indices_dorases = fpsample.bucket_fps_kdline_sampling(
            salient_surface_points, num_fps_salient_point, h=5, start_idx=0
        )
        sample["fps_indices_dorases"] = torch.from_numpy(fps_indices_dorases).long()  # [num_fps_salient_point,]

    return sample
# ---------------------------------------------------------------------------
# Script body: load the checkpoint, build the model, then reconstruct every
# mesh found in --input and export the result as a .glb file.
# ---------------------------------------------------------------------------
print(f"Loading checkpoint from {args.ckpt_path}")
# map_location="cpu" makes loading independent of the device the checkpoint was
# saved from; load_state_dict below copies the values into the CUDA model.
ckpt_dict = torch.load(args.ckpt_path, map_location="cpu", weights_only=True)

# keep only the model weights when the checkpoint stores them under "model"
if "model" in ckpt_dict:
    ckpt_dict = ckpt_dict["model"]

# instantiate model
print(f"Instantiating model from {args.config}")
model_config = importlib.import_module(args.config).make_config()
model = Model(model_config).eval().cuda().bfloat16()

# load weight
print(f"Loading weights from {args.ckpt_path}")
model.load_state_dict(ckpt_dict, strict=True)

# One output folder per run: vae_<last component of the config module>_<timestamp>
timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
workspace = os.path.join(args.output_dir, "vae_" + args.config.split(".")[-1] + "_" + timestamp)
os.makedirs(workspace, exist_ok=True)
# If the folder already existed, clear its top-level files. This mirrors the
# former `rm <workspace>/*` (non-hidden entries only, sub-directories kept)
# without passing a path built from CLI input through a shell.
for stale_path in glob.glob(os.path.join(workspace, "*")):
    if os.path.isfile(stale_path) or os.path.islink(stale_path):
        os.remove(stale_path)
print(f"Output directory: {workspace}")

# load dataset
# glob returns entries in arbitrary order; sort so that --limit selects the
# same subset on every run.
mesh_list = sorted(glob.glob(os.path.join(args.input, "*")))
if args.limit > 0:
    mesh_list = mesh_list[: args.limit]

for i, mesh_path in enumerate(mesh_list):
    print(f"Processing {i}/{len(mesh_list)}: {mesh_path}")

    # Strip the directory and only the final extension, so "a.v1.obj" and
    # "a.v2.obj" map to different output names.
    mesh_name = os.path.splitext(os.path.basename(mesh_path))[0]

    sample = prepare_input_from_mesh(
        mesh_path, num_fps_point=args.num_fps_point, num_fps_salient_point=args.num_fps_salient_point
    )
    # add a leading batch dimension of size 1 and move every tensor to the GPU
    sample = {key: tensor.unsqueeze(0).cuda() for key, tensor in sample.items()}

    # call vae
    with torch.inference_mode():
        output = model(sample, resolution=args.grid_res)

    # output["meshes"] holds one (vertices, faces) pair per batch item;
    # the batch has a single item here.
    vertices, faces = output["meshes"][0]

    mesh = trimesh.Trimesh(vertices, faces)
    # NOTE(review): 5e5 is presumably a target/maximum face count - confirm
    # against postprocess_mesh.
    mesh = postprocess_mesh(mesh, 5e5)

    mesh.export(os.path.join(workspace, f"{mesh_name}.glb"))
|