import argparse import os import torch import trimesh from cube3d.inference.engine import Engine, EngineFast from cube3d.mesh_utils.postprocessing import ( PYMESHLAB_AVAILABLE, create_pymeshset, postprocess_mesh, save_mesh, ) from cube3d.renderer import renderer def generate_mesh( engine, prompt, output_dir, output_name, resolution_base=8.0, disable_postprocess=False, top_p=None, ): mesh_v_f = engine.t2s( [prompt], use_kv_cache=True, resolution_base=resolution_base, top_p=top_p, ) vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1] obj_path = os.path.join(output_dir, f"{output_name}.obj") if PYMESHLAB_AVAILABLE: ms = create_pymeshset(vertices, faces) if not disable_postprocess: target_face_num = max(10000, int(faces.shape[0] * 0.1)) print(f"Postprocessing mesh to {target_face_num} faces") postprocess_mesh(ms, target_face_num, obj_path) save_mesh(ms, obj_path) else: print( "WARNING: pymeshlab is not available, using trimesh to export obj and skipping optional post processing." ) mesh = trimesh.Trimesh(vertices, faces) mesh.export(obj_path) return obj_path if __name__ == "__main__": parser = argparse.ArgumentParser(description="cube shape generation script") parser.add_argument( "--config-path", type=str, default="cube3d/configs/open_model.yaml", help="Path to the configuration YAML file.", ) parser.add_argument( "--output-dir", type=str, default="outputs/", help="Path to the output directory to store .obj and .gif files", ) parser.add_argument( "--gpt-ckpt-path", type=str, required=True, help="Path to the main GPT checkpoint file.", ) parser.add_argument( "--shape-ckpt-path", type=str, required=True, help="Path to the shape encoder/decoder checkpoint file.", ) parser.add_argument( "--fast-inference", help="Use optimized inference", default=False, action="store_true", ) parser.add_argument( "--prompt", type=str, required=True, help="Text prompt for generating a 3D mesh", ) parser.add_argument( "--top-p", type=float, default=None, help="Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.", ) parser.add_argument( "--render-gif", help="Render a turntable gif of the mesh", default=False, action="store_true", ) parser.add_argument( "--disable-postprocessing", help="Disable postprocessing on the mesh. This will result in a mesh with more faces.", default=False, action="store_true", ) parser.add_argument( "--resolution-base", type=float, default=8.0, help="Resolution base for the shape decoder.", ) args = parser.parse_args() os.makedirs(args.output_dir, exist_ok=True) device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu") print(f"Using device: {device}") # Initialize engine based on fast_inference flag if args.fast_inference: print( "Using cuda graphs, this will take some time to warmup and capture the graph." ) engine = EngineFast( args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device ) print("Compiled the graph.") else: engine = Engine( args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path, device=device ) # Generate meshes based on input source obj_path = generate_mesh( engine, args.prompt, args.output_dir, "output", args.resolution_base, args.disable_postprocessing, args.top_p, ) if args.render_gif: gif_path = renderer.render_turntable(obj_path, args.output_dir) print(f"Rendered turntable gif for {args.prompt} at `{gif_path}`") print(f"Generated mesh for {args.prompt} at `{obj_path}`")