|
# `spaces` is kept first, as in the original; presumably so it is loaded before
# any CUDA-initializing import (Hugging Face ZeroGPU guidance) — confirm.
import spaces

import os
import subprocess

import gradio as gr

from Anymate.args import ui_args, anymate_args
from Anymate.utils.ui_utils import process_input, vis_joint, vis_connectivity, vis_skinning, vis_all, prepare_blender_file
from Anymate.utils.ui_utils import get_result_joint, get_result_connectivity, get_result_skinning
from Anymate.utils.utils import load_checkpoint
|
|
|
|
|
# Paths of the three pretrained checkpoints this demo needs.
_CHECKPOINT_PATHS = (
    ui_args.checkpoint_joint,
    ui_args.checkpoint_conn,
    ui_args.checkpoint_skin,
)

# Fetch the checkpoints on first run if any of them is absent.
if not all(os.path.exists(path) for path in _CHECKPOINT_PATHS):
    print("Missing checkpoints, downloading them...")
    # check=True raises CalledProcessError if the script fails; the previous
    # os.system call silently ignored a non-zero exit status.
    subprocess.run(["bash", "Anymate/get_checkpoints.sh"], check=True)
    # The script can "succeed" without producing every file; fail here with a
    # clear message instead of deep inside load_checkpoint.
    still_missing = [path for path in _CHECKPOINT_PATHS if not os.path.exists(path)]
    if still_missing:
        raise FileNotFoundError(
            f"Checkpoints still missing after download: {still_missing}"
        )

# Load each model with 'cpu' passed to load_checkpoint (presumably the map
# location — confirm), then move it to the configured device.
model_joint = load_checkpoint(ui_args.checkpoint_joint, 'cpu', anymate_args.num_joints).to(anymate_args.device)
model_connectivity = load_checkpoint(ui_args.checkpoint_conn, 'cpu', anymate_args.num_joints).to(anymate_args.device)
model_skinning = load_checkpoint(ui_args.checkpoint_skin, 'cpu', anymate_args.num_joints).to(anymate_args.device)
|
|
|
@spaces.GPU |
|
def get_all_results(mesh_file, pc, eps=0.03, min_samples=1): |
|
|
|
joints = get_result_joint(mesh_file, model_joint, pc, eps=eps, min_samples=min_samples) |
|
conns = get_result_connectivity(mesh_file, model_connectivity, pc, joints) |
|
skins = get_result_skinning(mesh_file, model_skinning, pc, joints, conns) |
|
print("Finish Inference") |
|
return |
|
|
|
|
|
with gr.Blocks() as demo:
    gr.Markdown("""
    # Anymate: Auto-rigging 3D Objects

    [Project](https://anymate3d.github.io/)
    """)

    # Per-session state written by `process_input`. `pc` is passed to every
    # inference helper (presumably a sampled point cloud — confirm).
    # `normalized_mesh_file` is the mesh that the inference and visualization
    # callbacks read.
    pc = gr.State(value=None)
    normalized_mesh_file = gr.State(value=None)

    with gr.Row():
        with gr.Column():
            gr.Markdown("### Input")
            mesh_input = gr.Model3D(label="Input 3D Mesh", clear_color=[0.0, 0.0, 0.0, 0.0])

            gr.Markdown("### Sample Objects")
            sample_objects_dir = './samples'
            # Build the sample list defensively. A missing directory
            # (os.listdir -> FileNotFoundError) or a directory with no .obj
            # files (sample_objects[0] -> IndexError) used to crash the app at
            # startup.
            sample_files = (
                os.listdir(sample_objects_dir) if os.path.isdir(sample_objects_dir) else []
            )
            sample_objects = sorted(
                os.path.join(sample_objects_dir, f)
                for f in sample_files
                if f.endswith('.obj') and os.path.isfile(os.path.join(sample_objects_dir, f))
            )

            sample_dropdown = gr.Dropdown(
                label="Select Sample Object",
                choices=sample_objects,
                interactive=True,
                value=sample_objects[0] if sample_objects else None,
            )

            load_sample_btn = gr.Button("Load Sample")

        with gr.Column():
            gr.Markdown("### Output (wireframe display mode)")
            mesh_output = gr.Model3D(label="Output 3D Mesh", clear_color=[0.0, 0.0, 0.0, 0.0], display_mode="wireframe")

        with gr.Column():
            gr.Markdown("### (solid display mode & blender file)")
            mesh_output2 = gr.Model3D(label="Output 3D Mesh", clear_color=[0.0, 0.0, 0.0, 0.0], display_mode="solid")
            blender_file = gr.File(label="Output Blender File", scale=1)

    with gr.Row():
        process_all_btn = gr.Button("Run all models", scale=1)
        # Both values are forwarded via get_all_results to get_result_joint.
        eps = gr.Number(label="Epsilon", value=0.03, interactive=True, info="Controls the maximum distance between joints in a cluster")
        # min_samples is a count. precision=0 makes Gradio round the entry and
        # deliver an int rather than a possible float.
        min_samples = gr.Number(label="Min Samples", value=1, precision=0, interactive=True, info="Minimum number of joints required to form a cluster")

    # Whenever the input mesh changes, process_input refreshes the normalized
    # mesh, both output viewers, the Blender file and the `pc` state.
    mesh_input.change(
        process_input,
        inputs=mesh_input,
        outputs=[normalized_mesh_file, mesh_output, mesh_output2, blender_file, pc],
    )

    # Copy the chosen sample path into the input viewer, then process it.
    # NOTE(review): Gradio also fires `mesh_input.change` on programmatic
    # updates, so process_input may run twice on sample load — confirm
    # whether this explicit `.then` is still required.
    load_sample_btn.click(
        fn=lambda sample_path: sample_path if sample_path else None,
        inputs=[sample_dropdown],
        outputs=[mesh_input],
    ).then(
        process_input,
        inputs=mesh_input,
        outputs=[normalized_mesh_file, mesh_output, mesh_output2, blender_file, pc],
    )

    # Mirror the normalized mesh back into the input viewer.
    # NOTE(review): this updates mesh_input, which can re-trigger
    # mesh_input.change; presumably process_input is stable on an
    # already-normalized file, otherwise this could loop — confirm.
    normalized_mesh_file.change(
        lambda x: x,
        inputs=normalized_mesh_file,
        outputs=mesh_input,
    )

    # Full pipeline. Inference produces no direct outputs; the later steps
    # render the results and build the downloadable Blender file from the
    # normalized mesh path.
    process_all_btn.click(
        get_all_results,
        inputs=[normalized_mesh_file, pc, eps, min_samples],
        outputs=[],
    ).then(
        vis_all,
        inputs=[normalized_mesh_file],
        outputs=[mesh_output, mesh_output2],
    ).then(
        prepare_blender_file,
        inputs=[normalized_mesh_file],
        outputs=blender_file,
    )
|
|
|
def _main():
    """Start the Gradio demo server with default launch settings."""
    demo.launch()


if __name__ == "__main__":
    _main()
|
|