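"""Gradio demo for Roblox Cube 3D (Roblox/cube3d-v0.1): generate a 3D mesh from a text prompt."""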
import argparse
import os
import shutil
import sys
import uuid
from pathlib import Path

import gradio as gr
import torch
import trimesh
from huggingface_hub import snapshot_download

# Make the bundled `cube` repo importable before importing its modules.
sys.path.append("cube")
from cube3d.inference.engine import EngineFast
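# Shared state populated in __main__: the loaded inference engine and the output save directory.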
GLOBAL_STATE = {}
def gen_save_folder(max_size=200):
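    """Create a unique output folder under SAVE_DIR, deleting the oldest one
    once `max_size` folders already exist."""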
    os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)
    dirs = [f for f in Path(GLOBAL_STATE["SAVE_DIR"]).iterdir() if f.is_dir()]
    if len(dirs) >= max_size:
        oldest_dir = min(dirs, key=lambda x: x.stat().st_ctime)
        shutil.rmtree(oldest_dir)
        print(f"Removed the oldest folder: {oldest_dir}")
    new_folder = os.path.join(GLOBAL_STATE["SAVE_DIR"], str(uuid.uuid4()))
    os.makedirs(new_folder, exist_ok=True)
    print(f"Created new folder: {new_folder}")
    return new_folder

def handle_text_prompt(input_prompt):
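    """Run text-to-shape inference for one prompt and export the mesh as an OBJ file."""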
    mesh_v_f = GLOBAL_STATE["engine_fast"].t2s([input_prompt], use_kv_cache=True, resolution_base=8.0)
    # save output
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    save_folder = gen_save_folder()
    output_path = os.path.join(save_folder, "output.obj")
    trimesh.Trimesh(vertices=vertices, faces=faces).export(output_path)
    return output_path

def build_interface():
"""Build UI for gradio app
"""
title = "Cube 3D"
with gr.Blocks(theme=gr.themes.Soft(), title=title, fill_width=True) as interface:
gr.Markdown(
f"""
# {title}
"""
)
with gr.Row():
with gr.Column(scale=2):
with gr.Group():
input_text_box = gr.Textbox(
value=None,
label="Prompt",
lines=2,
)
with gr.Row():
submit_button = gr.Button("Submit", variant="primary")
with gr.Column(scale=3):
model3d = gr.Model3D(
label="Output", height="45em", interactive=False
)
submit_button.click(
handle_text_prompt,
inputs=[
input_text_box
],
outputs=[
model3d
]
)
return interface
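# Typical launch from the Space root (script filename assumed): python app.py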
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to the model config file",
        default="cube/cube3d/configs/open_model.yaml",
    )
    parser.add_argument(
        "--gpt_ckpt_path",
        type=str,
        help="Path to the GPT checkpoint",
        default="model_weights/shape_gpt.safetensors",
    )
    parser.add_argument(
        "--shape_ckpt_path",
        type=str,
        help="Path to the shape tokenizer checkpoint",
        default="model_weights/shape_tokenizer.safetensors",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="gradio_save_dir",
        help="Directory where generated meshes are saved",
    )
    args = parser.parse_args()
    # Download the released weights so the default checkpoint paths above resolve.
    snapshot_download(
        repo_id="Roblox/cube3d-v0.1",
        local_dir="./model_weights"
    )
    engine_fast = EngineFast(
        args.config_path,
        args.gpt_ckpt_path,
        args.shape_ckpt_path,
        device=torch.device("cuda"),
    )
    GLOBAL_STATE["engine_fast"] = engine_fast
    GLOBAL_STATE["SAVE_DIR"] = args.save_dir
    os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)
    demo = build_interface()
    demo.queue(default_concurrency_limit=None)
    demo.launch()