# Akash Garg
# adding variance slider for top_p
# f6a2f50
import argparse
import os
import torch
import trimesh
from cube3d.inference.engine import Engine, EngineFast
from cube3d.mesh_utils.postprocessing import (
PYMESHLAB_AVAILABLE,
create_pymeshset,
postprocess_mesh,
save_mesh,
)
from cube3d.renderer import renderer
def generate_mesh(
    engine,
    prompt,
    output_dir,
    output_name,
    resolution_base=8.0,
    disable_postprocess=False,
    top_p=None,
):
    """Generate a mesh from a text prompt and write it to an ``.obj`` file.

    The prompt is passed to ``engine.t2s`` (with the KV cache enabled) as a
    single-element batch, and the first result is taken as a
    ``(vertices, faces)`` pair. When pymeshlab is available the mesh is saved
    through it, optionally after postprocessing; otherwise the raw mesh is
    exported with trimesh and postprocessing is skipped.

    Args:
        engine: Object exposing ``t2s(prompts, use_kv_cache=..., resolution_base=...,
            top_p=...)``; the script passes an ``Engine`` or ``EngineFast``.
        prompt: Text prompt describing the shape to generate.
        output_dir: Directory that receives the ``.obj`` file. It is created
            if it does not exist yet.
        output_name: File name (without extension) of the exported mesh.
        resolution_base: Resolution base forwarded to ``engine.t2s``.
        disable_postprocess: If true, skip the pymeshlab postprocessing step
            and save the mesh as generated.
        top_p: Sampling parameter forwarded to ``engine.t2s`` unchanged;
            ``None`` is the default.

    Returns:
        Path of the written ``.obj`` file, i.e.
        ``os.path.join(output_dir, f"{output_name}.obj")``.
    """
    # Create the destination before running inference: callers other than the
    # CLI entry point may not have created it, and discovering that only at
    # save time would throw away the whole generation run. An empty
    # ``output_dir`` means "current directory" and needs no creation.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    mesh_v_f = engine.t2s(
        [prompt],
        use_kv_cache=True,
        resolution_base=resolution_base,
        top_p=top_p,
    )
    # Only one prompt was submitted, so only the first result is relevant.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    obj_path = os.path.join(output_dir, f"{output_name}.obj")

    if not PYMESHLAB_AVAILABLE:
        print(
            "WARNING: pymeshlab is not available, using trimesh to export obj and skipping optional post processing."
        )
        mesh = trimesh.Trimesh(vertices, faces)
        mesh.export(obj_path)
        return obj_path

    ms = create_pymeshset(vertices, faces)
    if not disable_postprocess:
        # Target 10% of the generated face count, but never fewer than 10000.
        # NOTE(review): relies on ``faces`` exposing ``.shape`` (array-like).
        target_face_num = max(10000, int(faces.shape[0] * 0.1))
        print(f"Postprocessing mesh to {target_face_num} faces")
        postprocess_mesh(ms, target_face_num, obj_path)

    save_mesh(ms, obj_path)
    return obj_path
if __name__ == "__main__":
    # Command-line options as (flag, add_argument keyword arguments) pairs,
    # registered in this order so the generated --help text keeps its layout.
    cli_options = (
        (
            "--config-path",
            {
                "type": str,
                "default": "cube3d/configs/open_model.yaml",
                "help": "Path to the configuration YAML file.",
            },
        ),
        (
            "--output-dir",
            {
                "type": str,
                "default": "outputs/",
                "help": "Path to the output directory to store .obj and .gif files",
            },
        ),
        (
            "--gpt-ckpt-path",
            {
                "type": str,
                "required": True,
                "help": "Path to the main GPT checkpoint file.",
            },
        ),
        (
            "--shape-ckpt-path",
            {
                "type": str,
                "required": True,
                "help": "Path to the shape encoder/decoder checkpoint file.",
            },
        ),
        (
            "--fast-inference",
            {
                "action": "store_true",
                "default": False,
                "help": "Use optimized inference",
            },
        ),
        (
            "--prompt",
            {
                "type": str,
                "required": True,
                "help": "Text prompt for generating a 3D mesh",
            },
        ),
        (
            "--top-p",
            {
                "type": float,
                "default": None,
                "help": "Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.",
            },
        ),
        (
            "--render-gif",
            {
                "action": "store_true",
                "default": False,
                "help": "Render a turntable gif of the mesh",
            },
        ),
        (
            "--disable-postprocessing",
            {
                "action": "store_true",
                "default": False,
                "help": "Disable postprocessing on the mesh. This will result in a mesh with more faces.",
            },
        ),
        (
            "--resolution-base",
            {
                "type": float,
                "default": 8.0,
                "help": "Resolution base for the shape decoder.",
            },
        ),
    )

    parser = argparse.ArgumentParser(description="cube shape generation script")
    for flag, spec in cli_options:
        parser.add_argument(flag, **spec)
    args = parser.parse_args()

    os.makedirs(args.output_dir, exist_ok=True)

    # Use CUDA when torch reports it as available, otherwise the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Both engine classes are constructed from the same positional arguments;
    # --fast-inference only decides which class is instantiated.
    engine_args = (args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path)
    if not args.fast_inference:
        engine = Engine(*engine_args, device=device)
    else:
        print(
            "Using cuda graphs, this will take some time to warmup and capture the graph."
        )
        engine = EngineFast(*engine_args, device=device)
        print("Compiled the graph.")

    # Generate the mesh for the prompt and write it as "output.obj".
    obj_path = generate_mesh(
        engine,
        args.prompt,
        args.output_dir,
        "output",
        resolution_base=args.resolution_base,
        disable_postprocess=args.disable_postprocessing,
        top_p=args.top_p,
    )

    # Optional turntable render of the exported mesh.
    if args.render_gif:
        gif_path = renderer.render_turntable(obj_path, args.output_dir)
        print(f"Rendered turntable gif for {args.prompt} at `{gif_path}`")

    print(f"Generated mesh for {args.prompt} at `{obj_path}`")