Spaces:
Running
on
L40S
Running
on
L40S
File size: 4,295 Bytes
616f571 f6a2f50 616f571 f6a2f50 616f571 f6a2f50 616f571 f6a2f50 616f571 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
import argparse
import os
import torch
import trimesh
from cube3d.inference.engine import Engine, EngineFast
from cube3d.mesh_utils.postprocessing import (
PYMESHLAB_AVAILABLE,
create_pymeshset,
postprocess_mesh,
save_mesh,
)
from cube3d.renderer import renderer
def generate_mesh(
    engine,
    prompt,
    output_dir,
    output_name,
    resolution_base=8.0,
    disable_postprocess=False,
    top_p=None,
):
    """Generate a mesh from a text prompt and export it as an ``.obj`` file.

    Args:
        engine: Object exposing a ``t2s`` method; the script entry point
            passes an ``Engine`` or ``EngineFast`` instance.
        prompt: Text prompt, forwarded to ``engine.t2s`` as a one-item list.
        output_dir: Directory that receives the ``.obj`` file. It is created
            when it does not exist yet; an empty string means the current
            working directory.
        output_name: File name of the mesh, without the ``.obj`` extension.
        resolution_base: Forwarded unchanged to ``engine.t2s``.
        disable_postprocess: When true, the ``postprocess_mesh`` step is
            skipped. Only has an effect when pymeshlab is available.
        top_p: Forwarded unchanged to ``engine.t2s``.

    Returns:
        The export path, ``<output_dir>/<output_name>.obj``.
    """
    # Callers other than the CLI entry point may pass a directory that does
    # not exist yet. Create it before the expensive generation step so a bad
    # path fails early instead of at export time.
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)

    mesh_v_f = engine.t2s(
        [prompt],
        use_kv_cache=True,
        resolution_base=resolution_base,
        top_p=top_p,
    )
    # A single prompt was submitted, so only the first result is used.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    obj_path = os.path.join(output_dir, f"{output_name}.obj")

    if PYMESHLAB_AVAILABLE:
        ms = create_pymeshset(vertices, faces)
        if not disable_postprocess:
            # Target 10% of faces.shape[0], but never fewer than 10,000.
            target_face_num = max(10000, int(faces.shape[0] * 0.1))
            print(f"Postprocessing mesh to {target_face_num} faces")
            postprocess_mesh(ms, target_face_num, obj_path)
        # Saved whether or not postprocessing ran, so the returned path is
        # always backed by an export call.
        save_mesh(ms, obj_path)
    else:
        print(
            "WARNING: pymeshlab is not available, using trimesh to export obj and skipping optional post processing."
        )
        mesh = trimesh.Trimesh(vertices, faces)
        mesh.export(obj_path)

    return obj_path
if __name__ == "__main__":
    # ---- Command-line interface -------------------------------------------
    parser = argparse.ArgumentParser(description="cube shape generation script")

    # Configuration and checkpoint locations.
    parser.add_argument(
        "--config-path",
        default="cube3d/configs/open_model.yaml",
        type=str,
        help="Path to the configuration YAML file.",
    )
    parser.add_argument(
        "--output-dir",
        default="outputs/",
        type=str,
        help="Path to the output directory to store .obj and .gif files",
    )
    parser.add_argument(
        "--gpt-ckpt-path",
        required=True,
        type=str,
        help="Path to the main GPT checkpoint file.",
    )
    parser.add_argument(
        "--shape-ckpt-path",
        required=True,
        type=str,
        help="Path to the shape encoder/decoder checkpoint file.",
    )

    # Generation options.
    parser.add_argument(
        "--fast-inference",
        action="store_true",
        default=False,
        help="Use optimized inference",
    )
    parser.add_argument(
        "--prompt",
        required=True,
        type=str,
        help="Text prompt for generating a 3D mesh",
    )
    parser.add_argument(
        "--top-p",
        default=None,
        type=float,
        help="Float < 1: Keep smallest set of tokens with cumulative probability ≥ top_p. Default None: deterministic generation.",
    )

    # Output options.
    parser.add_argument(
        "--render-gif",
        action="store_true",
        default=False,
        help="Render a turntable gif of the mesh",
    )
    parser.add_argument(
        "--disable-postprocessing",
        action="store_true",
        default=False,
        help="Disable postprocessing on the mesh. This will result in a mesh with more faces.",
    )
    parser.add_argument(
        "--resolution-base",
        default=8.0,
        type=float,
        help="Resolution base for the shape decoder.",
    )
    args = parser.parse_args()

    # ---- Environment setup ------------------------------------------------
    os.makedirs(args.output_dir, exist_ok=True)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    print(f"Using device: {device}")

    # ---- Engine construction ----------------------------------------------
    # Both engine classes are called with the same three paths and the device.
    engine_paths = (args.config_path, args.gpt_ckpt_path, args.shape_ckpt_path)
    if not args.fast_inference:
        engine = Engine(*engine_paths, device=device)
    else:
        print(
            "Using cuda graphs, this will take some time to warmup and capture the graph."
        )
        engine = EngineFast(*engine_paths, device=device)
        print("Compiled the graph.")

    # ---- Mesh generation and optional rendering ---------------------------
    obj_path = generate_mesh(
        engine=engine,
        prompt=args.prompt,
        output_dir=args.output_dir,
        output_name="output",
        resolution_base=args.resolution_base,
        disable_postprocess=args.disable_postprocessing,
        top_p=args.top_p,
    )
    if args.render_gif:
        gif_path = renderer.render_turntable(obj_path, args.output_dir)
        print(f"Rendered turntable gif for {args.prompt} at `{gif_path}`")
    print(f"Generated mesh for {args.prompt} at `{obj_path}`")
|