File size: 3,675 Bytes
be383f9
6d250b3
be383f9
 
 
 
7e70d08
be383f9
 
 
 
7e70d08
6d250b3
 
be383f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7836010
be383f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
7e70d08
 
 
 
 
 
 
be383f9
7e70d08
 
 
be383f9
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import argparse
import gradio as gr
import os
import torch
import trimesh
import sys
sys.path.append("cube")
from cube3d.inference.engine import EngineFast
from pathlib import Path
import uuid
import shutil
from huggingface_hub import snapshot_download


# Process-wide state shared with the Gradio callbacks. The __main__ block
# stores the loaded EngineFast under "engine_fast" and the output root
# directory under "SAVE_DIR" before the UI is launched.
GLOBAL_STATE = {}

def _folder_ctime(path):
    """Return ``path``'s ``st_ctime``, or ``-inf`` if it no longer exists.

    Another request may delete a folder between listing the save directory
    and stat-ing its entries; vanished folders sort first so they are the
    ones "pruned" (a no-op) rather than crashing the call.
    """
    try:
        return path.stat().st_ctime
    except FileNotFoundError:
        return float("-inf")

def gen_save_folder(max_size=200):
    """Create a fresh, uniquely named output folder under ``GLOBAL_STATE["SAVE_DIR"]``.

    The save directory acts as a bounded store of per-request folders: before
    creating the new folder, the oldest existing sub-folders (ordered by
    ``st_ctime``) are deleted until fewer than ``max_size`` remain, leaving
    room for the one created here.

    Args:
        max_size: Maximum number of sub-folders to keep in the save directory,
            counting the folder created by this call. Values <= 0 behave like 1.

    Returns:
        Path (``str``) of the newly created, empty folder.
    """
    save_root = GLOBAL_STATE["SAVE_DIR"]
    os.makedirs(save_root, exist_ok=True)

    dirs = sorted(
        (p for p in Path(save_root).iterdir() if p.is_dir()),
        key=_folder_ctime,
    )

    # Remove as many of the oldest folders as needed, not just one, so the
    # cap is restored even if the directory was already over-full. The
    # max(..., 0) also makes an empty directory with max_size <= 0 safe.
    n_remove = max(len(dirs) - max_size + 1, 0)
    for oldest_dir in dirs[:n_remove]:
        try:
            shutil.rmtree(oldest_dir)
        except FileNotFoundError:
            # Already removed by a concurrent call; nothing left to delete.
            continue
        print(f"Removed the oldest folder: {oldest_dir}")

    new_folder = os.path.join(save_root, str(uuid.uuid4()))
    os.makedirs(new_folder, exist_ok=True)
    print(f"Created new folder: {new_folder}")

    return new_folder

def handle_text_prompt(input_prompt):
    """Generate a mesh from a text prompt and export it as an OBJ file.

    Runs the shared engine stored in ``GLOBAL_STATE["engine_fast"]`` on a
    single-prompt batch, writes the first result to ``output.obj`` inside a
    fresh folder from ``gen_save_folder``, and returns that file's path for
    the ``gr.Model3D`` output.

    Args:
        input_prompt: Text describing the desired shape (the Prompt textbox).

    Returns:
        Filesystem path (``str``) of the exported ``.obj`` file.

    Raises:
        gr.Error: If the prompt is missing or whitespace-only. This avoids
            spending GPU time on an empty request and shows the user a message.
    """
    if not (input_prompt and input_prompt.strip()):
        raise gr.Error("Please enter a text prompt.")

    mesh_v_f = GLOBAL_STATE["engine_fast"].t2s([input_prompt], use_kv_cache=True, resolution_base=8.0)
    # One prompt was submitted, so use the first result. Its [0]/[1] entries
    # are taken as vertices/faces (assumed from this indexing; confirm
    # against EngineFast.t2s).
    vertices, faces = mesh_v_f[0][0], mesh_v_f[0][1]

    save_folder = gen_save_folder()
    output_path = os.path.join(save_folder, "output.obj")
    trimesh.Trimesh(vertices=vertices, faces=faces).export(output_path)
    return output_path

def build_interface():
    """Assemble the Gradio Blocks UI for the Cube 3D text-to-shape demo.

    Layout: a title header, then a single row with the prompt column on the
    left (textbox plus a Submit button) and a read-only 3D model viewer on
    the right. Clicking Submit passes the textbox contents to
    ``handle_text_prompt`` and displays the file path it returns in the viewer.

    Returns:
        The constructed ``gr.Blocks`` app; launching is left to the caller.
    """
    title = "Cube 3D"
    header_md = f"""
            # {title}
            """
    with gr.Blocks(theme=gr.themes.Soft(), title=title, fill_width=True) as demo_blocks:
        gr.Markdown(header_md)

        with gr.Row():
            # Left side: prompt entry and submit control.
            with gr.Column(scale=2):
                with gr.Group():
                    prompt_box = gr.Textbox(value=None, label="Prompt", lines=2)
                with gr.Row():
                    run_btn = gr.Button("Submit", variant="primary")
            # Right side: preview of the generated mesh (display only).
            with gr.Column(scale=3):
                mesh_viewer = gr.Model3D(label="Output", height="45em", interactive=False)

        run_btn.click(handle_text_prompt, inputs=[prompt_box], outputs=[mesh_viewer])

    return demo_blocks

if __name__ == "__main__":

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--config_path",
        type=str,
        help="Path to the config file",
        default="cube/cube3d/configs/open_model.yaml",
    )
    parser.add_argument(
        "--gpt_ckpt_path",
        type=str,
        help="Path to the gpt ckpt path",
        default="model_weights/shape_gpt.safetensors",
    )
    parser.add_argument(
        "--shape_ckpt_path",
        type=str,
        help="Path to the shape ckpt path",
        default="model_weights/shape_tokenizer.safetensors",
    )
    parser.add_argument(
        "--save_dir",
        type=str,
        help="Directory where generated meshes are saved",
        default="gradio_save_dir",
    )

    args = parser.parse_args()

    # Download the released checkpoints into ./model_weights, which is where
    # the --gpt_ckpt_path / --shape_ckpt_path defaults point.
    snapshot_download(
        repo_id="Roblox/cube3d-v0.1",
        local_dir="./model_weights"
    )

    # Use the parsed CLI paths. In particular the first argument must be the
    # config file (default: cube/cube3d/configs/open_model.yaml), not a
    # .safetensors weights file.
    engine_fast = EngineFast(
        args.config_path,
        args.gpt_ckpt_path,
        args.shape_ckpt_path,
        device=torch.device("cuda"),
    )
    GLOBAL_STATE["engine_fast"] = engine_fast
    GLOBAL_STATE["SAVE_DIR"] = args.save_dir
    os.makedirs(GLOBAL_STATE["SAVE_DIR"], exist_ok=True)

    demo = build_interface()
    # NOTE(review): default_concurrency_limit=None removes Gradio's per-event
    # concurrency limit, so several requests may call the single shared
    # EngineFast at once. Confirm EngineFast.t2s tolerates that; otherwise a
    # limit of 1 would serialize GPU inference.
    demo.queue(default_concurrency_limit=None)
    demo.launch()