File size: 4,203 Bytes
be383f9
6d250b3
be383f9
 
 
 
ad63d32
ef1530a
abdc26c
 
ef1530a
ff14b75
 
 
 
0b2312f
0c04a17
be383f9
 
 
7e70d08
6d250b3
 
be383f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6cf650b
f6a2f50
 
 
be383f9
 
 
a68b44f
be383f9
 
 
 
 
 
 
 
 
 
 
f6a2f50
be383f9
 
 
 
 
 
 
 
 
 
 
466d1ab
be383f9
 
 
 
 
 
 
 
 
 
f6a2f50
 
be383f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e70d08
 
 
 
4586de6
7e70d08
 
f641176
7e70d08
 
 
be383f9
 
 
 
 
 
 
de83d86
be383f9
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
import argparse
import gradio as gr
import os
import torch
import trimesh
import sys
from pathlib import Path

# Make the ``cube`` directory that sits next to this file importable, so the
# ``cube3d`` package import below resolves. The membership check keeps
# repeated executions of this module from adding duplicate sys.path entries.
pathdir = Path(__file__).parent / 'cube'
if pathdir.as_posix() not in sys.path:
    sys.path.append(pathdir.as_posix())

from cube3d.inference.engine import EngineFast, Engine
from pathlib import Path
import uuid
import shutil
from huggingface_hub import snapshot_download


# Module-wide mutable registry shared between the script entry point and the
# request handlers. Keys set when run as __main__: "engine_fast" (the loaded
# EngineFast instance) and "SAVE_DIR" (root directory for generated outputs).
GLOBAL_STATE = dict()

def gen_save_folder(max_size=200, save_dir=None):
    """Create a new uniquely named output folder, pruning the oldest ones first.

    Keeps the number of sub-directories in ``save_dir`` bounded. Before the
    new folder is created, the oldest sub-directories are deleted until fewer
    than ``max_size`` remain. Plain files in ``save_dir`` are never counted
    or removed.

    Args:
        max_size: Maximum number of sub-directories to keep in ``save_dir``,
            including the one created by this call.
        save_dir: Directory to create the folder in. Defaults to
            ``GLOBAL_STATE["SAVE_DIR"]``, so existing callers are unaffected.

    Returns:
        The path (``str``) of the newly created folder.
    """
    if save_dir is None:
        save_dir = GLOBAL_STATE["SAVE_DIR"]
    os.makedirs(save_dir, exist_ok=True)

    # Order folders oldest-first by st_ctime. On POSIX this is the
    # inode-change time, not strictly the creation time.
    dirs = sorted(
        (f for f in Path(save_dir).iterdir() if f.is_dir()),
        key=lambda p: p.stat().st_ctime,
    )

    # Remove every folder in excess, not just one. This lets the count
    # converge back to the limit if it was ever exceeded (e.g. max_size
    # lowered). It also avoids min() on an empty list when max_size <= 0.
    excess = len(dirs) - max_size + 1
    for oldest_dir in dirs[:max(excess, 0)]:
        shutil.rmtree(oldest_dir)
        print(f"Removed the oldest folder: {oldest_dir}")

    new_folder = os.path.join(save_dir, str(uuid.uuid4()))
    os.makedirs(new_folder, exist_ok=True)
    print(f"Created new folder: {new_folder}")

    return new_folder

def handle_text_prompt(input_prompt, variance=0):
    """Generate a mesh for a text prompt and export it as a GLB file.

    Args:
        input_prompt: Text typed into the UI's "Prompt" box.
        variance: Slider value (0-99 in the UI). ``0`` passes ``top_p=None``
            to the engine. Any other value passes
            ``top_p=(100 - variance) / 100``, so higher variance gives a
            smaller ``top_p``.

    Returns:
        The path of the exported ``output.glb`` inside a freshly created
        save folder.

    Raises:
        gr.Error: If the prompt is missing or whitespace-only. The UI shows
            the message instead of the request reaching the engine.
    """
    print(f"prompt: {input_prompt}, variance: {variance}")
    # Reject blank prompts before touching the engine or creating a save
    # folder; otherwise they would be sent to generation unchanged.
    if input_prompt is None or not str(input_prompt).strip():
        raise gr.Error("Please enter a prompt.")
    top_p = None if variance == 0 else (100 - variance) / 100.0
    mesh_v_f = GLOBAL_STATE["engine_fast"].t2s(
        [input_prompt], use_kv_cache=True, resolution_base=8.0, top_p=top_p
    )
    # A single prompt was submitted, so use the first result's
    # (vertices, faces) pair.
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]
    save_folder = gen_save_folder()
    output_path = os.path.join(save_folder, "output.glb")
    trimesh.Trimesh(vertices=vertices, faces=faces).export(output_path)
    return output_path

def build_interface():
    """Assemble the Gradio Blocks UI for the Cube 3D demo.

    The left column holds a prompt textbox and a "Variance" slider. The
    right column holds a read-only 3D model viewer. The Submit button
    forwards (prompt, variance) to ``handle_text_prompt`` and shows the
    returned file in the viewer.

    Returns:
        The constructed ``gr.Blocks`` app (not yet queued or launched).
    """
    app_title = "Cube 3D"
    with gr.Blocks(theme=gr.themes.Soft(), title=app_title, fill_width=True) as demo_ui:
        gr.Markdown(
            f"""
            # {app_title}
            # Check out our [Github](https://github.com/Roblox/cube) to try it on your own machine!
            """
        )

        with gr.Row():
            with gr.Column(scale=2):
                with gr.Group():
                    prompt_box = gr.Textbox(value=None, label="Prompt", lines=2)
                    variance_slider = gr.Slider(
                        minimum=0, maximum=99, step=1, value=0, label="Variance"
                    )
                with gr.Row():
                    run_button = gr.Button("Submit", variant="primary")
            with gr.Column(scale=3):
                viewer = gr.Model3D(label="Output", height="45em", interactive=False)

        # Event wiring must be registered inside the Blocks context.
        run_button.click(
            handle_text_prompt,
            inputs=[prompt_box, variance_slider],
            outputs=[viewer],
        )

    return demo_ui

def main():
    """Parse CLI options, fetch weights, load the engine and launch the demo.

    The checkpoint paths come from ``--gpt_ckpt_path`` and
    ``--shape_ckpt_path``. Their defaults point inside ``./model_weights``,
    the directory the Hugging Face snapshot is downloaded into.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to the config file",
        default="cube/cube3d/configs/open_model.yaml",
    )
    parser.add_argument(
        "--gpt_ckpt_path",
        type=str,
        help="Path to the gpt ckpt path",
        default="model_weights/shape_gpt.safetensors",
    )
    parser.add_argument(
        "--shape_ckpt_path",
        type=str,
        help="Path to the shape ckpt path",
        default="model_weights/shape_tokenizer.safetensors",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        default="gradio_save_dir",
    )

    args = parser.parse_args()
    snapshot_download(
        repo_id="Roblox/cube3d-v0.1",
        local_dir="./model_weights"
    )
    # Use the parsed checkpoint flags rather than fixed paths, so that
    # --gpt_ckpt_path / --shape_ckpt_path actually take effect.
    engine_fast = EngineFast(
        args.config_path,
        args.gpt_ckpt_path,
        args.shape_ckpt_path,
        device=torch.device("cuda"),
    )
    GLOBAL_STATE["engine_fast"] = engine_fast
    GLOBAL_STATE["SAVE_DIR"] = args.save_dir
    os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)

    demo = build_interface()
    # Allow one generation at a time; requests queue behind the running one.
    demo.queue(default_concurrency_limit=1)
    demo.launch()


if __name__ == "__main__":
    main()