File size: 8,062 Bytes
45a206f
cf4dedb
744eb4e
45a206f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cf4dedb
744eb4e
 
 
28a2da9
744eb4e
 
 
 
 
45a206f
 
 
744eb4e
45a206f
 
 
744eb4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45a206f
 
744eb4e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45a206f
 
 
 
 
 
 
 
 
 
 
744eb4e
 
 
 
 
 
45a206f
 
 
 
 
 
 
 
 
 
 
744eb4e
 
 
 
 
 
45a206f
 
 
 
 
 
 
 
 
 
 
744eb4e
 
 
 
 
 
 
45a206f
744eb4e
 
 
 
3d67d08
 
 
744eb4e
 
 
 
45a206f
744eb4e
 
 
 
 
 
 
 
 
45a206f
744eb4e
 
 
 
 
 
 
 
45a206f
 
 
 
744eb4e
 
45a206f
 
 
 
744eb4e
 
45a206f
 
 
 
744eb4e
 
45a206f
 
 
 
744eb4e
cf4dedb
45a206f
 
 
 
 
 
 
 
 
 
 
 
 
 
cf4dedb
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
import spaces
import gradio as gr
import os
from Anymate.args import ui_args, anymate_args
from Anymate.utils.ui_utils import process_input, vis_joint, vis_connectivity, vis_skinning, vis_all, prepare_blender_file
from Anymate.utils.ui_utils import get_result_joint, get_result_connectivity, get_result_skinning

from Anymate.utils.utils import load_checkpoint

def _ensure_checkpoints(paths):
    """Download the Anymate checkpoints if any file in *paths* is missing.

    Runs ``Anymate/get_checkpoints.sh`` through bash (argument list, no shell
    string) and verifies afterwards that every path now exists.

    Raises:
        subprocess.CalledProcessError: If the download script exits non-zero.
        FileNotFoundError: If some checkpoint is still missing after the script ran.
    """
    if all(os.path.exists(path) for path in paths):
        return

    import subprocess  # only needed on the download path

    print("Missing checkpoints, downloading them...")
    # check=True: fail here with the script's exit status instead of letting
    # load_checkpoint fail later with a less obvious error.
    subprocess.run(["bash", "Anymate/get_checkpoints.sh"], check=True)

    still_missing = [path for path in paths if not os.path.exists(path)]
    if still_missing:
        raise FileNotFoundError(
            f"Checkpoints still missing after running Anymate/get_checkpoints.sh: {still_missing}"
        )


def _load_model(checkpoint_path):
    """Load one checkpoint onto the CPU, then move it to ``anymate_args.device``."""
    return load_checkpoint(checkpoint_path, 'cpu', anymate_args.num_joints).to(anymate_args.device)


# Check if checkpoints exist, if not download them
_ensure_checkpoints((ui_args.checkpoint_joint, ui_args.checkpoint_conn, ui_args.checkpoint_skin))

# One model per pipeline stage; read as module globals by get_all_results.
model_joint = _load_model(ui_args.checkpoint_joint)
model_connectivity = _load_model(ui_args.checkpoint_conn)
model_skinning = _load_model(ui_args.checkpoint_skin)

@spaces.GPU
def get_all_results(mesh_file, pc, eps=0.03, min_samples=1):
    """Run joint, connectivity and skinning inference on the loaded mesh.

    The stages run in sequence:
    - joints are predicted first;
    - connectivity is predicted from those joints;
    - skinning is predicted from both.

    Nothing is returned, because the Gradio event calling this uses
    ``outputs=[]``. The follow-up steps (``vis_all``, ``prepare_blender_file``)
    only receive the mesh path, so presumably the ``get_result_*`` helpers
    persist their results keyed on ``mesh_file`` — TODO confirm in
    Anymate.utils.ui_utils.

    Args:
        mesh_file: Path of the normalized mesh produced by ``process_input``.
        pc: The ``pc`` state produced by ``process_input`` alongside the mesh.
        eps: Clustering distance threshold forwarded to ``get_result_joint``.
        min_samples: Minimum cluster size forwarded to ``get_result_joint``;
            coerced to ``int``.

    Raises:
        gr.Error: If no mesh has been loaded or a clustering parameter is
            empty, so the user sees a message instead of a stack trace.
    """
    if mesh_file is None or pc is None:
        raise gr.Error("Please load a mesh before running the models.")
    if eps is None or min_samples is None:
        raise gr.Error("Epsilon and Min Samples must both be set.")
    # gr.Number may deliver a float (e.g. 1.0) unless precision=0, but a
    # minimum sample count must be an integer.
    min_samples = int(round(min_samples))

    joints = get_result_joint(mesh_file, model_joint, pc, eps=eps, min_samples=min_samples)
    conns = get_result_connectivity(mesh_file, model_connectivity, pc, joints)
    # The skinning result is not used here; presumably the helper persists it
    # for vis_all / prepare_blender_file (see docstring).
    get_result_skinning(mesh_file, model_skinning, pc, joints, conns)
    print("Finish Inference")


with gr.Blocks() as demo:
    gr.Markdown("""
    # Anymate: Auto-rigging 3D Objects
    [Project](https://anymate3d.github.io/)
    """)

    # Per-session state written by process_input: `pc` (presumably the sampled
    # point cloud — confirm in ui_utils) and the normalized mesh path that all
    # later steps operate on.
    pc = gr.State(value=None)
    normalized_mesh_file = gr.State(value=None)

    with gr.Row():
        with gr.Column():
            # Input section
            gr.Markdown("### Input")
            mesh_input = gr.Model3D(label="Input 3D Mesh", clear_color=[0.0, 0.0, 0.0, 0.0])

            # Sample 3D objects section: every .obj file directly inside ./samples
            gr.Markdown("### Sample Objects")
            sample_objects_dir = './samples'
            # Tolerate a missing samples directory instead of crashing at startup.
            sample_files = os.listdir(sample_objects_dir) if os.path.isdir(sample_objects_dir) else []
            sample_objects = sorted(
                os.path.join(sample_objects_dir, f)
                for f in sample_files
                if f.endswith('.obj') and os.path.isfile(os.path.join(sample_objects_dir, f))
            )

            sample_dropdown = gr.Dropdown(
                label="Select Sample Object",
                choices=sample_objects,
                interactive=True,
                # An empty samples directory would otherwise raise IndexError here.
                value=sample_objects[0] if sample_objects else None,
            )

            load_sample_btn = gr.Button("Load Sample")

        with gr.Column():
            # Output section
            gr.Markdown("### Output (wireframe display mode)")
            mesh_output = gr.Model3D(label="Output 3D Mesh", clear_color=[0.0, 0.0, 0.0, 0.0], display_mode="wireframe")

        with gr.Column():
            # Output section
            gr.Markdown("### (solid display mode & blender file)")
            mesh_output2 = gr.Model3D(label="Output 3D Mesh", clear_color=[0.0, 0.0, 0.0, 0.0], display_mode="solid")

            blender_file = gr.File(label="Output Blender File", scale=1)

    with gr.Row():
        process_all_btn = gr.Button("Run all models", scale=1)

    # Parameters for DBSCAN clustering algorithm used to adjust joint clustering
    eps = gr.Number(label="Epsilon", value=0.03, interactive=True, info="Controls the maximum distance between joints in a cluster")
    # precision=0 makes Gradio round and deliver an int; a sample count must be >= 1.
    min_samples = gr.Number(label="Min Samples", value=1, precision=0, minimum=1, interactive=True, info="Minimum number of joints required to form a cluster")

    # Any new input mesh is normalized and the outputs/state are reset.
    mesh_input.change(
        process_input,
        inputs=mesh_input,
        outputs=[normalized_mesh_file, mesh_output, mesh_output2, blender_file, pc]
    )

    # NOTE(review): mesh_input.change above also fires when this handler writes
    # mesh_input, so process_input may run twice per sample load — confirm
    # before dropping one of the two triggers.
    load_sample_btn.click(
        fn=lambda sample_path: sample_path if sample_path else None,
        inputs=[sample_dropdown],
        outputs=[mesh_input]
    ).then(
        process_input,
        inputs=mesh_input,
        outputs=[normalized_mesh_file, mesh_output, mesh_output2, blender_file, pc]
    )

    # Show the normalized mesh in the input viewer once it is available.
    normalized_mesh_file.change(
        lambda x: x,
        inputs=normalized_mesh_file,
        outputs=mesh_input
    )

    # Inference first; visualization and Blender export only after it
    # succeeded (.success rather than .then, which would run even on failure).
    process_all_btn.click(
        get_all_results,
        inputs=[normalized_mesh_file, pc, eps, min_samples],
        outputs=[]
    ).success(
        vis_all,
        inputs=[normalized_mesh_file],
        outputs=[mesh_output, mesh_output2]
    ).success(
        prepare_blender_file,
        inputs=[normalized_mesh_file],
        outputs=blender_file
    )

if __name__ == "__main__":
    demo.launch()