# Spaces: Running on L40S
import argparse | |
import os | |
import torch | |
import trimesh | |
from cube3d.inference.engine import Engine, EngineFast | |
from cube3d.mesh_utils.postprocessing import ( | |
PYMESHLAB_AVAILABLE, | |
create_pymeshset, | |
postprocess_mesh, | |
save_mesh, | |
) | |
from cube3d.renderer import renderer | |
def generate_mesh(
    engine,
    prompt,
    output_dir,
    output_name,
    resolution_base=8.0,
    disable_postprocess=False,
    top_p=None,
):
    """Generate a mesh for ``prompt`` and write it to ``<output_dir>/<output_name>.obj``.

    Args:
        engine: Object exposing ``t2s``; it is called with a one-element
            prompt list and ``use_kv_cache=True``.
        prompt: Text prompt describing the shape to generate.
        output_dir: Directory that receives the ``.obj`` file. It is created
            if it does not exist yet.
        output_name: File name of the mesh, without the ``.obj`` extension.
        resolution_base: Forwarded unchanged to ``engine.t2s``.
        disable_postprocess: When true, ``postprocess_mesh`` is skipped. Only
            relevant when pymeshlab is available.
        top_p: Forwarded unchanged to ``engine.t2s``.

    Returns:
        The path of the exported ``.obj`` file.
    """
    # Create the destination before the (expensive) generation step so that a
    # missing or unwritable directory fails fast instead of after inference.
    # An empty ``output_dir`` means "current directory"; os.makedirs("") raises.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    mesh_v_f = engine.t2s(
        [prompt],
        use_kv_cache=True,
        resolution_base=resolution_base,
        top_p=top_p,
    )
    # Only one prompt was submitted, so only the first result is used.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    obj_path = os.path.join(output_dir, f"{output_name}.obj")

    if PYMESHLAB_AVAILABLE:
        ms = create_pymeshset(vertices, faces)
        if not disable_postprocess:
            # Target 10% of the generated face count, but never below 10000.
            target_face_num = max(10000, int(faces.shape[0] * 0.1))
            print(f"Postprocessing mesh to {target_face_num} faces")
            postprocess_mesh(ms, target_face_num, obj_path)
        save_mesh(ms, obj_path)
    else:
        print(
            "WARNING: pymeshlab is not available, using trimesh to export obj and skipping optional post processing."
        )
        mesh = trimesh.Trimesh(vertices, faces)
        mesh.export(obj_path)

    return obj_path
if __name__ == "__main__":
    # Command-line entry point: parse options, build the inference engine,
    # generate one mesh for the prompt and optionally render a turntable gif.
    cli = argparse.ArgumentParser(description="cube shape generation script")
    cli.add_argument(
        "--config-path",
        type=str,
        default="cube3d/configs/open_model.yaml",
        help="Path to the configuration YAML file.",
    )
    cli.add_argument(
        "--output-dir",
        type=str,
        default="outputs/",
        help="Path to the output directory to store .obj and .gif files",
    )
    cli.add_argument(
        "--gpt-ckpt-path",
        type=str,
        required=True,
        help="Path to the main GPT checkpoint file.",
    )
    cli.add_argument(
        "--shape-ckpt-path",
        type=str,
        required=True,
        help="Path to the shape encoder/decoder checkpoint file.",
    )
    # Boolean switches rely on store_true's implicit default of False.
    cli.add_argument(
        "--fast-inference",
        action="store_true",
        help="Use optimized inference",
    )
    cli.add_argument(
        "--prompt",
        type=str,
        required=True,
        help="Text prompt for generating a 3D mesh",
    )
    cli.add_argument(
        "--top-p",
        type=float,
        default=None,
        help="Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.",
    )
    cli.add_argument(
        "--render-gif",
        action="store_true",
        help="Render a turntable gif of the mesh",
    )
    cli.add_argument(
        "--disable-postprocessing",
        action="store_true",
        help="Disable postprocessing on the mesh. This will result in a mesh with more faces.",
    )
    cli.add_argument(
        "--resolution-base",
        type=float,
        default=8.0,
        help="Resolution base for the shape decoder.",
    )
    opts = cli.parse_args()

    os.makedirs(opts.output_dir, exist_ok=True)

    # Prefer the GPU when one is visible to torch, otherwise run on the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # Pick the engine implementation requested via --fast-inference.
    use_fast = opts.fast_inference
    engine_cls = EngineFast if use_fast else Engine
    if use_fast:
        print(
            "Using cuda graphs, this will take some time to warmup and capture the graph."
        )
    engine = engine_cls(
        opts.config_path,
        opts.gpt_ckpt_path,
        opts.shape_ckpt_path,
        device=device,
    )
    if use_fast:
        print("Compiled the graph.")

    # Generate the mesh for the single prompt given on the command line.
    obj_path = generate_mesh(
        engine,
        opts.prompt,
        opts.output_dir,
        "output",
        resolution_base=opts.resolution_base,
        disable_postprocess=opts.disable_postprocessing,
        top_p=opts.top_p,
    )

    if opts.render_gif:
        gif_path = renderer.render_turntable(obj_path, opts.output_dir)
        print(f"Rendered turntable gif for {opts.prompt} at `{gif_path}`")
    print(f"Generated mesh for {opts.prompt} at `{obj_path}`")