animan123's picture
Move slider next to input
466d1ab
raw
history blame
4.2 kB
import argparse
import gradio as gr
import os
import torch
import trimesh
import sys
from pathlib import Path
# Add the `cube` directory that sits next to this file to sys.path before the
# `cube3d` import below (presumably that is where the package lives — confirm).
pathdir = Path(__file__).parent / 'cube'
sys.path.append(pathdir.as_posix())
# print(__file__)
# print(os.listdir())
# print(os.listdir('cube'))
# print(pathdir.as_posix())
from cube3d.inference.engine import EngineFast, Engine
from pathlib import Path
import uuid
import shutil
from huggingface_hub import snapshot_download
# Module-wide state shared between the startup code and the request handlers.
# Keys used in this file:
#   "engine_fast" - the EngineFast instance used to generate meshes
#   "SAVE_DIR"    - directory under which per-request output folders are created
GLOBAL_STATE = {}
def gen_save_folder(max_size=200, save_dir=None):
    """Create and return a fresh, uniquely named output folder.

    The parent directory is created if it does not exist. Before the new
    folder is made, the oldest sub-folders (ordered by ``st_ctime``) are
    deleted until fewer than ``max_size`` remain, so that after this call the
    parent holds at most ``max_size`` folders (and always at least the new
    one).

    Args:
        max_size: Maximum number of sub-folders to keep under the parent
            directory, counting the one created by this call.
        save_dir: Parent directory to work in. When omitted,
            ``GLOBAL_STATE["SAVE_DIR"]`` is used.

    Returns:
        Path of the newly created folder, as a string.
    """
    if save_dir is None:
        save_dir = GLOBAL_STATE["SAVE_DIR"]
    os.makedirs(save_dir, exist_ok=True)

    # Oldest first, so the slice below picks the folders to evict.
    existing = sorted(
        (entry for entry in Path(save_dir).iterdir() if entry.is_dir()),
        key=lambda entry: entry.stat().st_ctime,
    )

    # Leave room for the folder about to be created. A non-positive count
    # means we are under the cap and nothing is removed; an empty directory
    # is handled naturally because the slice is then empty as well.
    excess = len(existing) - max_size + 1
    for oldest_dir in existing[:max(excess, 0)]:
        shutil.rmtree(oldest_dir)
        print(f"Removed the oldest folder: {oldest_dir}")

    new_folder = os.path.join(save_dir, str(uuid.uuid4()))
    os.makedirs(new_folder, exist_ok=True)
    print(f"Created new folder: {new_folder}")
    return new_folder
def handle_text_prompt(input_prompt, variance):
    """Generate a mesh for a text prompt and export it as a GLB file.

    Args:
        input_prompt: Prompt text; sent to the engine as a one-element batch.
        variance: Number from the UI slider. ``0`` results in ``top_p=None``
            being passed to the engine; any other value is converted to
            ``top_p = (100 - variance) / 100``.

    Returns:
        Path of the exported ``output.glb`` inside a newly created save folder.
    """
    print(f"prompt: {input_prompt}, variance: {variance}")

    if variance == 0:
        top_p = None
    else:
        top_p = (100 - variance) / 100.0

    engine = GLOBAL_STATE["engine_fast"]
    generated = engine.t2s(
        [input_prompt],
        use_kv_cache=True,
        resolution_base=8.0,
        top_p=top_p,
    )

    # A single prompt was submitted, so only the first result is used;
    # its elements 0 and 1 are taken as the vertices and faces.
    first_result = generated[0]
    vertices = first_result[0]
    faces = first_result[1]

    target_dir = gen_save_folder()
    output_path = os.path.join(target_dir, "output.glb")
    mesh = trimesh.Trimesh(vertices=vertices, faces=faces)
    mesh.export(output_path)
    return output_path
def build_interface():
"""Build UI for gradio app
"""
# Summary: declares a prompt textbox, a variance slider and a submit button
# that feed handle_text_prompt, plus a 3D viewer for the result. The Blocks
# object is returned without being launched; that is left to the caller.
# `title` is used both as the Blocks title and as the Markdown heading.
title = "Cube 3D"
with gr.Blocks(theme=gr.themes.Soft(), title=title, fill_width=True) as interface:
# Page header: the title plus a link to the GitHub repository.
gr.Markdown(
f"""
# {title}
# Check out our [Github](https://github.com/Roblox/cube) to try it on your own machine!
"""
)
with gr.Row():
# scale=2 column, opened before the prompt, slider and submit button.
with gr.Column(scale=2):
with gr.Group():
# Two-line text input for the prompt; starts empty.
input_text_box = gr.Textbox(
value=None,
label="Prompt",
lines=2,
)
# Integer slider from 0 to 99, default 0; passed to the handler as `variance`.
variance = gr.Slider(minimum=0, maximum=99, step=1, value=0, label="Variance")
with gr.Row():
submit_button = gr.Button("Submit", variant="primary")
# scale=3 column, opened before the read-only 3D viewer.
with gr.Column(scale=3):
model3d = gr.Model3D(
label="Output", height="45em", interactive=False
)
# Clicking Submit calls handle_text_prompt(prompt, variance); the value it
# returns is sent to the Model3D viewer.
submit_button.click(
handle_text_prompt,
inputs=[
input_text_box,
variance
],
outputs=[
model3d
]
)
return interface
# Script entry point: parse CLI options, download the model weights, build the
# inference engine, then build and launch the Gradio app.
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to the config file",
        default="cube/cube3d/configs/open_model.yaml",
    )
    parser.add_argument(
        "--gpt_ckpt_path",
        type=str,
        help="Path to the gpt ckpt path",
        default="model_weights/shape_gpt.safetensors",
    )
    parser.add_argument(
        "--shape_ckpt_path",
        type=str,
        help="Path to the shape ckpt path",
        default="model_weights/shape_tokenizer.safetensors",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="gradio_save_dir",
    )
    args = parser.parse_args()

    # Download the published weights into ./model_weights, which is where the
    # default checkpoint paths above point.
    snapshot_download(
        repo_id="Roblox/cube3d-v0.1",
        local_dir="./model_weights"
    )

    # Use the parsed checkpoint arguments. Previously these were parsed but
    # then replaced by hard-coded paths, so --gpt_ckpt_path and
    # --shape_ckpt_path had no effect. The defaults resolve to the same files
    # as the old hard-coded values.
    engine_fast = EngineFast(
        args.config_path,
        args.gpt_ckpt_path,
        args.shape_ckpt_path,
        device=torch.device("cuda"),
    )

    # Publish the shared state before the UI is built so the request handler
    # can find the engine and the save directory.
    GLOBAL_STATE["engine_fast"] = engine_fast
    GLOBAL_STATE["SAVE_DIR"] = args.save_dir
    os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)

    demo = build_interface()
    # Default concurrency limit of 1 for queued events.
    demo.queue(default_concurrency_limit=1)
    demo.launch()