init
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- Anymate/.gitignore +26 -0
- Anymate/__init__.py +0 -0
- Anymate/args.py +22 -0
- Anymate/blender_script.py +747 -0
- Anymate/checkpoints/.gitkeep +0 -0
- Anymate/configs/.gitkeep +0 -0
- Anymate/configs/conn.yaml +40 -0
- Anymate/configs/conn_token.yaml +40 -0
- Anymate/configs/diffusion.yaml +49 -0
- Anymate/configs/diffusion_concat.yaml +46 -0
- Anymate/configs/diffusion_cross.yaml +51 -0
- Anymate/configs/joints.yaml +40 -0
- Anymate/configs/joints_implicit.yaml +40 -0
- Anymate/configs/joints_triplane.yaml +40 -0
- Anymate/configs/skin.yaml +40 -0
- Anymate/configs/skin_multi.yaml +40 -0
- Anymate/dataset.py +62 -0
- Anymate/get_checkpoints.sh +22 -0
- Anymate/get_datasets.sh +12 -0
- Anymate/model.py +360 -0
- Anymate/models/__init__.py +0 -0
- Anymate/models/conn.py +195 -0
- Anymate/models/diffusion.py +483 -0
- Anymate/models/joint.py +282 -0
- Anymate/models/skin.py +309 -0
- Anymate/tmp/.gitkeep +0 -0
- Anymate/utils/dataset_utils.py +129 -0
- Anymate/utils/diffusion_encoder.py +258 -0
- Anymate/utils/diffusion_utils.py +314 -0
- Anymate/utils/eval_utils.py +225 -0
- Anymate/utils/loss_utils.py +56 -0
- Anymate/utils/render_utils.py +1169 -0
- Anymate/utils/train_utils.py +406 -0
- Anymate/utils/ui_utils.py +284 -0
- Anymate/utils/ui_utils_bpy.py +134 -0
- Anymate/utils/utils.py +77 -0
- Anymate/utils/vol_utils.py +135 -0
- Render.py +17 -0
- ThirdParty/PointLLM/.gitignore +12 -0
- ThirdParty/PointLLM/README.md +353 -0
- ThirdParty/PointLLM/__init__.py +0 -0
- ThirdParty/PointLLM/pointllm/__init__.py +1 -0
- ThirdParty/PointLLM/pointllm/conversation.py +375 -0
- ThirdParty/PointLLM/pointllm/data/__init__.py +3 -0
- ThirdParty/PointLLM/pointllm/data/modelnet.py +147 -0
- ThirdParty/PointLLM/pointllm/data/modelnet_config/ModelNet40.yaml +8 -0
- ThirdParty/PointLLM/pointllm/data/object_point_dataset.py +250 -0
- ThirdParty/PointLLM/pointllm/data/utils.py +236 -0
- ThirdParty/PointLLM/pointllm/eval/PointLLM_chat.py +157 -0
- ThirdParty/PointLLM/pointllm/eval/chat_gradio.py +394 -0
Anymate/.gitignore
ADDED
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
*.pt
|
3 |
+
*.tar
|
4 |
+
*.tar
|
5 |
+
*.txt
|
6 |
+
*.glb*
|
7 |
+
*.obj
|
8 |
+
*.ckpt
|
9 |
+
*.blend
|
10 |
+
*.blend1
|
11 |
+
test_*
|
12 |
+
|
13 |
+
blender-*
|
14 |
+
*.json*
|
15 |
+
*.glb
|
16 |
+
*.gltf
|
17 |
+
*.fbx
|
18 |
+
*.FBX
|
19 |
+
*.dae
|
20 |
+
*.obj
|
21 |
+
*.mtl
|
22 |
+
*.binvox
|
23 |
+
*.csv
|
24 |
+
*.tga
|
25 |
+
*.png
|
26 |
+
*.jpg
|
Anymate/__init__.py
ADDED
File without changes
|
Anymate/args.py
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
class AnymateArgs:
|
2 |
+
def __init__(self):
|
3 |
+
# self.encoder = "miche"
|
4 |
+
# self.decoder = "transformer_latent"
|
5 |
+
# self.dataset = "train"
|
6 |
+
# self.run_name = "miche-transformer_latent-train-8gpu-finetune"
|
7 |
+
self.checkpoint_joint = "Anymate/checkpoints/joint/bert-transformer_latent-train-8gpu-finetune.pth.tar"
|
8 |
+
self.checkpoint_conn = "Anymate/checkpoints/conn/bert-attendjoints_con_combine-train-8gpu-finetune.pth.tar"
|
9 |
+
self.checkpoint_skin = "Anymate/checkpoints/skin/bert-attendjoints_combine-train-8gpu-finetune.pth.tar"
|
10 |
+
|
11 |
+
self.device = "cuda"
|
12 |
+
self.num_joints = 96
|
13 |
+
|
14 |
+
|
15 |
+
class UIArgs:
|
16 |
+
def __init__(self):
|
17 |
+
self.checkpoint_joint = "Anymate/checkpoints/joint/bert-transformer_latent-train-8gpu-finetune.pth.tar"
|
18 |
+
self.checkpoint_conn = "Anymate/checkpoints/conn/bert-attendjoints_con_combine-train-8gpu-finetune.pth.tar"
|
19 |
+
self.checkpoint_skin = "Anymate/checkpoints/skin/bert-attendjoints_combine-train-8gpu-finetune.pth.tar"
|
20 |
+
|
21 |
+
ui_args = UIArgs()
|
22 |
+
anymate_args = AnymateArgs()
|
Anymate/blender_script.py
ADDED
@@ -0,0 +1,747 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import bpy
|
2 |
+
import mathutils
|
3 |
+
from mathutils import Vector, Matrix
|
4 |
+
|
5 |
+
import os
|
6 |
+
import sys
|
7 |
+
import random
|
8 |
+
import numpy as np
|
9 |
+
import json
|
10 |
+
import argparse
|
11 |
+
|
12 |
+
|
13 |
+
IMPORT_FUNCTIONS = {
|
14 |
+
"obj": bpy.ops.wm.obj_import,
|
15 |
+
"glb": bpy.ops.import_scene.gltf,
|
16 |
+
"gltf": bpy.ops.import_scene.gltf,
|
17 |
+
"usd": bpy.ops.import_scene.usd,
|
18 |
+
"fbx": bpy.ops.import_scene.fbx,
|
19 |
+
"stl": bpy.ops.import_mesh.stl,
|
20 |
+
"usda": bpy.ops.import_scene.usda,
|
21 |
+
"dae": bpy.ops.wm.collada_import,
|
22 |
+
"ply": bpy.ops.import_mesh.ply,
|
23 |
+
"abc": bpy.ops.wm.alembic_import,
|
24 |
+
"blend": bpy.ops.wm.append,
|
25 |
+
}
|
26 |
+
|
27 |
+
def load_object(object_path: str) -> None:
|
28 |
+
"""Loads a model with a supported file extension into the scene.
|
29 |
+
|
30 |
+
Args:
|
31 |
+
object_path (str): Path to the model file.
|
32 |
+
|
33 |
+
Raises:
|
34 |
+
ValueError: If the file extension is not supported.
|
35 |
+
|
36 |
+
Returns:
|
37 |
+
None
|
38 |
+
"""
|
39 |
+
file_extension = object_path.split(".")[-1].lower()
|
40 |
+
if file_extension is None:
|
41 |
+
raise ValueError(f"Unsupported file type: {object_path}")
|
42 |
+
|
43 |
+
# load from existing import functions
|
44 |
+
import_function = IMPORT_FUNCTIONS[file_extension]
|
45 |
+
|
46 |
+
if file_extension == "blend":
|
47 |
+
import_function(directory=object_path, link=False)
|
48 |
+
elif file_extension in {"glb", "gltf"}:
|
49 |
+
import_function(filepath=object_path, merge_vertices=True)
|
50 |
+
else:
|
51 |
+
import_function(filepath=object_path)
|
52 |
+
|
53 |
+
####################### save json ################################
|
54 |
+
def save_json(output_path, mesh_obj, armature_obj, extra=None, arm_name=False):
|
55 |
+
# makedirs output_path
|
56 |
+
os.makedirs(output_path, exist_ok=True)
|
57 |
+
|
58 |
+
# start retrieve the information of mesh, skining and rigging
|
59 |
+
|
60 |
+
#1. retrieve the information of rigging, save the world matrix of the amature object
|
61 |
+
total_armature_info = {}
|
62 |
+
for obj in armature_obj:
|
63 |
+
# depsgraph = bpy.context.evaluated_depsgraph_get()
|
64 |
+
# obj = obj.evaluated_get(depsgraph)
|
65 |
+
armature_info = {}
|
66 |
+
armature_info["world_matrix"] = [list(row) for row in obj.matrix_world.copy()]
|
67 |
+
translation = obj.matrix_world.translation
|
68 |
+
for bone in obj.pose.bones:
|
69 |
+
bone_info = {}
|
70 |
+
bone_info["head_local"] = list(bone.head.copy())
|
71 |
+
bone_info["head_world"] = list((obj.matrix_world.to_3x3() @ bone.head+translation).copy())
|
72 |
+
# bone_info["matrix_local"] = [list(row) for row in bone.matrix_local.copy()]
|
73 |
+
bone_info["tail_local"] = list(bone.tail.copy())
|
74 |
+
bone_info["tail_world"] = list((obj.matrix_world.to_3x3() @ bone.tail+translation).copy())
|
75 |
+
|
76 |
+
if bone.parent:
|
77 |
+
bone_info["parent"] = bone.parent.name.replace(" ", "_")
|
78 |
+
if arm_name:
|
79 |
+
bone_info["parent"] = obj.name + "--" + bone_info["parent"]
|
80 |
+
else:
|
81 |
+
bone_info["parent"] = None
|
82 |
+
bone_info["children"] = []
|
83 |
+
if bone.children:
|
84 |
+
for child in bone.children:
|
85 |
+
if arm_name:
|
86 |
+
bone_info["children"].append(obj.name + "--" + child.name.replace(" ", "_"))
|
87 |
+
else:
|
88 |
+
bone_info["children"].append(child.name.replace(" ", "_"))
|
89 |
+
bone_name = bone.name.replace(" ", "_")
|
90 |
+
if arm_name:
|
91 |
+
bone_name = obj.name + "--" + bone_name
|
92 |
+
armature_info[bone_name] = bone_info
|
93 |
+
obj_name = obj.name.replace(" ", "_")
|
94 |
+
total_armature_info[obj.name] = armature_info
|
95 |
+
|
96 |
+
|
97 |
+
#2. retrieve the informatioon of skining
|
98 |
+
total_skinning_info = {}
|
99 |
+
for obj in mesh_obj:
|
100 |
+
vertex_groups = obj.vertex_groups
|
101 |
+
# if not vertex_groups:
|
102 |
+
# continue
|
103 |
+
# for group in vertex_groups:
|
104 |
+
skinning_info = {}
|
105 |
+
skinning_info["world_matrix"] = [list(row) for row in obj.matrix_world.copy()]
|
106 |
+
weight_info = []
|
107 |
+
for vertex in obj.data.vertices:
|
108 |
+
vertex_info = {}
|
109 |
+
for group in vertex.groups:
|
110 |
+
name = vertex_groups[group.group].name
|
111 |
+
name = name.replace(" ", "_")
|
112 |
+
if arm_name:
|
113 |
+
arm_modifier = [modifier for modifier in obj.modifiers if modifier.type == 'ARMATURE']
|
114 |
+
assert(len(arm_modifier) == 1)
|
115 |
+
name = arm_modifier[0].object.name + "--" + name
|
116 |
+
weight = group.weight
|
117 |
+
vertex_info[name] = weight
|
118 |
+
weight_info.append(vertex_info)
|
119 |
+
skinning_info["weight"] = weight_info
|
120 |
+
obj_name = obj.name.replace(" ", "_")
|
121 |
+
total_skinning_info[obj_name]=skinning_info
|
122 |
+
|
123 |
+
|
124 |
+
rigging_file_path = os.path.join(output_path, "rigging.json")
|
125 |
+
if extra:
|
126 |
+
rigging_file_path = rigging_file_path.replace("rigging.json", f'rigging_{extra}.json')
|
127 |
+
with open(rigging_file_path, "w") as f:
|
128 |
+
json.dump(total_armature_info, f, indent = 2)
|
129 |
+
|
130 |
+
skining_file_path = os.path.join(output_path, "skining.json")
|
131 |
+
if extra:
|
132 |
+
skining_file_path = skining_file_path.replace("skining.json", f'skining_{extra}.json')
|
133 |
+
with open(skining_file_path, "w") as f:
|
134 |
+
json.dump(total_skinning_info, f , indent = 2)
|
135 |
+
|
136 |
+
|
137 |
+
return rigging_file_path
|
138 |
+
|
139 |
+
|
140 |
+
def apply_skinning_weights(json_file):
|
141 |
+
|
142 |
+
with open(json_file, "r") as f:
|
143 |
+
skinning_data = json.load(f)
|
144 |
+
|
145 |
+
armature_obj = bpy.data.objects.get("Armature")
|
146 |
+
if not armature_obj:
|
147 |
+
print("Error: Armature object 'Armature' not found.")
|
148 |
+
return
|
149 |
+
|
150 |
+
# 将所有网格对象放置在骨骼对象的子集中
|
151 |
+
count = 0
|
152 |
+
for obj in bpy.context.scene.objects:
|
153 |
+
if obj.type == 'MESH':
|
154 |
+
obj.parent = armature_obj
|
155 |
+
count += 1
|
156 |
+
|
157 |
+
print("total mesh count:", count)
|
158 |
+
|
159 |
+
for obj in bpy.context.scene.objects:
|
160 |
+
vertex_index = 0
|
161 |
+
if obj.type == 'MESH':
|
162 |
+
mesh_name = obj.name
|
163 |
+
if mesh_name in skinning_data:
|
164 |
+
skinning_info = skinning_data[mesh_name]
|
165 |
+
if "weight" in skinning_info:
|
166 |
+
print("Applying skinning data for mesh:", mesh_name)
|
167 |
+
vertex_index = 0
|
168 |
+
for vertex_weight in skinning_info["weight"]:
|
169 |
+
for bone_name, weight_value in vertex_weight.items():
|
170 |
+
vertex_group = obj.vertex_groups.get(bone_name)
|
171 |
+
if vertex_group is None:
|
172 |
+
vertex_group = obj.vertex_groups.new(name=bone_name)
|
173 |
+
print("Vertex group created:", bone_name)
|
174 |
+
vertex_group.add([vertex_index], weight_value, 'REPLACE')
|
175 |
+
vertex_index += 1
|
176 |
+
else:
|
177 |
+
print("No skinning data found for mesh:", mesh_name)
|
178 |
+
for obj in bpy.context.scene.objects:
|
179 |
+
if obj.type == 'MESH':
|
180 |
+
modifier = obj.modifiers.new(name="Armature", type='ARMATURE')
|
181 |
+
modifier.object = armature_obj
|
182 |
+
modifier.use_vertex_groups = True
|
183 |
+
print("Armature modifier added to mesh:", obj.name)
|
184 |
+
|
185 |
+
def reload_rigging(rigging_file_path):
|
186 |
+
with open(rigging_file_path, "r") as f:
|
187 |
+
total_armature_info = json.load(f)
|
188 |
+
|
189 |
+
bpy.ops.object.armature_add()
|
190 |
+
armature_obj = bpy.context.object
|
191 |
+
armature_obj.name = "Armature"
|
192 |
+
|
193 |
+
bpy.ops.object.mode_set(mode='EDIT')
|
194 |
+
bpy.ops.armature.select_all(action='SELECT')
|
195 |
+
bpy.ops.armature.delete()
|
196 |
+
bpy.ops.object.mode_set(mode='OBJECT')
|
197 |
+
bpy.ops.object.mode_set(mode='EDIT')
|
198 |
+
|
199 |
+
world_matrix = mathutils.Matrix([[1, 0, 0, 0],
|
200 |
+
[0, 1, 0, 0],
|
201 |
+
[0, 0, 1, 0],
|
202 |
+
[0, 0, 0, 1]])
|
203 |
+
armature_obj.matrix_world = world_matrix
|
204 |
+
|
205 |
+
for armature_name, armature_info in total_armature_info.items():
|
206 |
+
for bone_name, bone_info in armature_info.items():
|
207 |
+
if bone_name == "world_matrix":
|
208 |
+
continue
|
209 |
+
bone = armature_obj.data.edit_bones.new(bone_name)
|
210 |
+
bone.head = bone_info["head_world"]
|
211 |
+
bone.tail = bone_info["tail_world"]
|
212 |
+
|
213 |
+
for bone_name, bone_info in armature_info.items():
|
214 |
+
if bone_name == "world_matrix":
|
215 |
+
continue
|
216 |
+
bone = armature_obj.data.edit_bones[bone_name]
|
217 |
+
parent_name = bone_info["parent"]
|
218 |
+
if parent_name:
|
219 |
+
parent_bone = armature_obj.data.edit_bones[parent_name]
|
220 |
+
bone.parent = parent_bone
|
221 |
+
edit_len = len(armature_obj.data.edit_bones.keys())
|
222 |
+
bpy.ops.object.mode_set(mode='OBJECT')
|
223 |
+
bone_len = len(armature_obj.data.bones.keys())
|
224 |
+
assert(edit_len == bone_len, "bone number not match!" + str(edit_len) + " " + str(bone_len))
|
225 |
+
bpy.ops.object.select_all(action='DESELECT')
|
226 |
+
armature_obj.select_set(True)
|
227 |
+
bpy.context.view_layer.objects.active = armature_obj
|
228 |
+
print("Rigging information has been reloaded!")
|
229 |
+
|
230 |
+
############################# reload json ################################
|
231 |
+
def reload_json(folder_path, version=0, export = None):
|
232 |
+
bpy.ops.wm.read_homefile(use_empty=True)
|
233 |
+
if version == 0:
|
234 |
+
obj_path = os.path.join(folder_path, "object.obj")
|
235 |
+
skinning_file_path = os.path.join(folder_path, "skining.json")
|
236 |
+
rigging_file_path = os.path.join(folder_path, "rigging.json")
|
237 |
+
elif version == 1:
|
238 |
+
obj_path = os.path.join(folder_path, "join.obj")
|
239 |
+
skinning_file_path = os.path.join(folder_path, "skining_norig.json")
|
240 |
+
rigging_file_path = os.path.join(folder_path, "rigging_norig.json")
|
241 |
+
elif version == 2:
|
242 |
+
obj_path = os.path.join(folder_path, "join.obj")
|
243 |
+
skinning_file_path = os.path.join(folder_path, "skining_norig2.json")
|
244 |
+
rigging_file_path = os.path.join(folder_path, "rigging_norig2.json")
|
245 |
+
# import_obj(obj_path)
|
246 |
+
load_object(obj_path)
|
247 |
+
reload_rigging(rigging_file_path)
|
248 |
+
apply_skinning_weights(skinning_file_path)
|
249 |
+
if export:
|
250 |
+
bpy.ops.wm.save_as_mainfile(filepath=export)
|
251 |
+
print("Done!")
|
252 |
+
|
253 |
+
|
254 |
+
def reset_scene() -> None:
|
255 |
+
"""Resets the scene to a clean state.
|
256 |
+
|
257 |
+
Returns:
|
258 |
+
None
|
259 |
+
"""
|
260 |
+
# delete everything that isn't part of a camera or a light
|
261 |
+
for obj in bpy.data.objects:
|
262 |
+
if obj.type not in {"CAMERA", "LIGHT"}:
|
263 |
+
bpy.data.objects.remove(obj, do_unlink=True)
|
264 |
+
|
265 |
+
# delete all the materials
|
266 |
+
for material in bpy.data.materials:
|
267 |
+
bpy.data.materials.remove(material, do_unlink=True)
|
268 |
+
|
269 |
+
# delete all the textures
|
270 |
+
for texture in bpy.data.textures:
|
271 |
+
bpy.data.textures.remove(texture, do_unlink=True)
|
272 |
+
|
273 |
+
# delete all the images
|
274 |
+
for image in bpy.data.images:
|
275 |
+
bpy.data.images.remove(image, do_unlink=True)
|
276 |
+
|
277 |
+
|
278 |
+
def save_mesh(path, mtl=False, obj_path=None):
|
279 |
+
if mtl:
|
280 |
+
# save the blend file
|
281 |
+
bpy.ops.wm.save_as_mainfile(filepath=obj_path + '/object.blend')
|
282 |
+
# reopen the blend file
|
283 |
+
bpy.ops.wm.open_mainfile(filepath=obj_path + '/object.blend')
|
284 |
+
# unpack all the materials and textures to obj_path
|
285 |
+
bpy.ops.file.unpack_all(method='WRITE_LOCAL')
|
286 |
+
# save to .obj without material
|
287 |
+
bpy.ops.wm.obj_export(filepath=path, export_materials=mtl, export_uv=mtl, export_triangulated_mesh=True)
|
288 |
+
|
289 |
+
|
290 |
+
def get_root_obj(obj):
|
291 |
+
if not obj.parent:
|
292 |
+
return obj
|
293 |
+
return get_root_obj(obj.parent)
|
294 |
+
|
295 |
+
def normalize(objs):
|
296 |
+
# bpy.ops.object.select_all(action='DESELECT')
|
297 |
+
# # select objs and join them
|
298 |
+
# for obj in objs:
|
299 |
+
# obj.select_set(True)
|
300 |
+
# bpy.context.view_layer.objects.active = objs[0]
|
301 |
+
# name_join = objs[0].name
|
302 |
+
# bpy.ops.object.join()
|
303 |
+
# obj_join = bpy.context.active_object
|
304 |
+
# print(obj_join.matrix_world)
|
305 |
+
# print(name_join)
|
306 |
+
# assert(name_join == obj_join.name)
|
307 |
+
|
308 |
+
objs_eval = []
|
309 |
+
depsgraph = bpy.context.evaluated_depsgraph_get()
|
310 |
+
for obj in objs:
|
311 |
+
objs_eval.append(obj.evaluated_get(depsgraph))
|
312 |
+
|
313 |
+
vertices = []
|
314 |
+
for obj in objs_eval:
|
315 |
+
for v in obj.data.vertices:
|
316 |
+
vertices.append(obj.matrix_world @ Vector((v.co.x, v.co.y, v.co.z, 1)))
|
317 |
+
|
318 |
+
vertices = np.array(vertices)
|
319 |
+
min_x, min_y, min_z, _ = np.min(vertices, axis=0)
|
320 |
+
max_x, max_y, max_z, _ = np.max(vertices, axis=0)
|
321 |
+
|
322 |
+
# print(min_x, min_y, min_z)
|
323 |
+
# print(max_x, max_y, max_z)
|
324 |
+
|
325 |
+
scale_x = 1 / (max_x - min_x)
|
326 |
+
scale_y = 1 / (max_y - min_y)
|
327 |
+
scale_z = 1 / (max_z - min_z)
|
328 |
+
scale_min = min(scale_x, scale_y, scale_z)
|
329 |
+
|
330 |
+
assert scale_min < 1e6
|
331 |
+
|
332 |
+
translate_x = - (max_x + min_x) / 2 * scale_min
|
333 |
+
translate_y = - (max_y + min_y) / 2 * scale_min
|
334 |
+
translate_z = - min_z * scale_min
|
335 |
+
|
336 |
+
# form transformation matrix
|
337 |
+
trans = Matrix.Translation((translate_x, translate_y, translate_z))
|
338 |
+
|
339 |
+
scale = Matrix.Scale(scale_min, 4, (1, 0, 0)) @ Matrix.Scale(scale_min, 4, (0, 1, 0)) @ Matrix.Scale(scale_min, 4, (0, 0, 1))
|
340 |
+
|
341 |
+
# print(trans, scale)
|
342 |
+
|
343 |
+
|
344 |
+
root = get_root_obj(objs[0])
|
345 |
+
# print(root.name)
|
346 |
+
# print(root.scale)
|
347 |
+
# print(root.location)
|
348 |
+
# print(root.matrix_world)
|
349 |
+
# root.location = mathutils.Vector(root.location) + mathutils.Vector((translate_x, translate_y, translate_z))
|
350 |
+
# root.scale = mathutils.Vector(root.scale) * mathutils.Vector((scale_x, scale_y, scale_z))
|
351 |
+
|
352 |
+
# add the extra transformation to the root object's world matrix
|
353 |
+
root.matrix_world = trans @ scale @ root.matrix_world
|
354 |
+
# print(root.name)
|
355 |
+
# print(root.scale)
|
356 |
+
# print(root.location)
|
357 |
+
# print(root.matrix_world)
|
358 |
+
|
359 |
+
# refresh
|
360 |
+
bpy.context.view_layer.update()
|
361 |
+
|
362 |
+
######### check if its successful
|
363 |
+
# objs_eval = []
|
364 |
+
# depsgraph = bpy.context.evaluated_depsgraph_get()
|
365 |
+
# for obj in objs:
|
366 |
+
# objs_eval.append(obj.evaluated_get(depsgraph))
|
367 |
+
|
368 |
+
# vertices = []
|
369 |
+
# for obj in objs_eval:
|
370 |
+
# for v in obj.data.vertices:
|
371 |
+
# vertices.append(obj.matrix_world @ Vector((v.co.x, v.co.y, v.co.z, 1)))
|
372 |
+
|
373 |
+
# vertices = np.array(vertices)
|
374 |
+
# min_x, min_y, min_z, _ = np.min(vertices, axis=0)
|
375 |
+
# max_x, max_y, max_z, _ = np.max(vertices, axis=0)
|
376 |
+
|
377 |
+
# print(min_x, min_y, min_z)
|
378 |
+
# print(max_x, max_y, max_z)
|
379 |
+
|
380 |
+
def remesh(objs, target=5000):
|
381 |
+
num_v = {}
|
382 |
+
for obj in objs:
|
383 |
+
num_v[obj] = len(obj.data.vertices)
|
384 |
+
|
385 |
+
# sort the num_v dict and make it a dict again
|
386 |
+
num_v_sort = sorted(num_v.items(), key=lambda x: x[1], reverse=True)
|
387 |
+
|
388 |
+
# print(num_v_sort)
|
389 |
+
total_v = sum([num_v[obj] for obj in num_v])
|
390 |
+
|
391 |
+
iters = 0
|
392 |
+
while total_v > target and iters<20:
|
393 |
+
reduce = []
|
394 |
+
for obj, v in num_v_sort:
|
395 |
+
reduce.append(obj)
|
396 |
+
if sum([num_v[oo] for oo in reduce]) > 0.5 * total_v:
|
397 |
+
break
|
398 |
+
for obj in reduce:
|
399 |
+
# check if have shape key
|
400 |
+
if obj.data.shape_keys is not None:
|
401 |
+
# remove obj from num_v
|
402 |
+
num_v.pop(obj)
|
403 |
+
continue
|
404 |
+
|
405 |
+
ratio = 0.5
|
406 |
+
# apply decimate modifier
|
407 |
+
bpy.context.view_layer.objects.active = obj
|
408 |
+
bpy.ops.object.modifier_add(type='DECIMATE')
|
409 |
+
bpy.context.object.modifiers["Decimate"].ratio = ratio
|
410 |
+
bpy.ops.object.modifier_apply(modifier="Decimate")
|
411 |
+
# update num_v
|
412 |
+
num_v[obj] = len(obj.data.vertices)
|
413 |
+
total_v = sum([num_v[obj] for obj in num_v])
|
414 |
+
num_v_sort = sorted(num_v.items(), key=lambda x: x[1], reverse=True)
|
415 |
+
# print(num_v_sort)
|
416 |
+
iters+=1
|
417 |
+
|
418 |
+
|
419 |
+
def get_parents(obj):
|
420 |
+
if not obj.parent:
|
421 |
+
return [obj.name]
|
422 |
+
parents = get_parents(obj.parent)
|
423 |
+
parents.append(obj.name)
|
424 |
+
return parents
|
425 |
+
|
426 |
+
def check(objs, arm):
|
427 |
+
# assert('Sketchfab_model' in bpy.data.objects)
|
428 |
+
|
429 |
+
# root_arm = get_root_obj(arm)
|
430 |
+
# for obj in objs:
|
431 |
+
# if root_arm != get_root_obj(obj):
|
432 |
+
# print('not same root')
|
433 |
+
# return -1
|
434 |
+
# return 1
|
435 |
+
|
436 |
+
# action_num = 0
|
437 |
+
# actions = bpy.data.actions
|
438 |
+
# for act in actions:
|
439 |
+
# action_num += 1
|
440 |
+
# fcurves = act.fcurves
|
441 |
+
# data_paths = []
|
442 |
+
# not_pose = False
|
443 |
+
# for fcurve in fcurves:
|
444 |
+
# data_paths.append(fcurve.data_path)
|
445 |
+
# if not fcurve.data_path.startswith('pose.bones'):
|
446 |
+
# # print(fcurve.data_path)
|
447 |
+
# not_pose = True
|
448 |
+
# # return -1
|
449 |
+
# if not_pose:
|
450 |
+
# print('zyhsb')
|
451 |
+
# print(data_paths)
|
452 |
+
# return -1
|
453 |
+
# return action_num
|
454 |
+
|
455 |
+
for obj in objs:
|
456 |
+
vertex_groups = obj.vertex_groups
|
457 |
+
# if not vertex_groups:
|
458 |
+
# continue
|
459 |
+
# for group in vertex_groups:
|
460 |
+
for vertex in obj.data.vertices:
|
461 |
+
vertex_info = {}
|
462 |
+
for group in vertex.groups:
|
463 |
+
name = vertex_groups[group.group].name
|
464 |
+
name = name.replace(" ", "_")
|
465 |
+
if True:
|
466 |
+
arm_modifier = [modifier for modifier in obj.modifiers if modifier.type == 'ARMATURE']
|
467 |
+
if len(arm_modifier) != 1:
|
468 |
+
print('zyhsb', len(arm_modifier))
|
469 |
+
return -2
|
470 |
+
# name = arm_modifier[0].object.name + "--" + name
|
471 |
+
return 1
|
472 |
+
|
473 |
+
# for obj in objs:
|
474 |
+
# if obj.data.shape_keys is not None:
|
475 |
+
# return 1
|
476 |
+
# # only 942!!!
|
477 |
+
# return 0
|
478 |
+
|
479 |
+
|
480 |
+
def delete(objs):
|
481 |
+
# check if the mesh object has skinning weight
|
482 |
+
for obj in objs:
|
483 |
+
vertex_groups = obj.vertex_groups
|
484 |
+
if not vertex_groups:
|
485 |
+
# delete the object
|
486 |
+
bpy.data.objects.remove(obj)
|
487 |
+
# print('delete!!!')
|
488 |
+
meshes = []
|
489 |
+
for obj in bpy.context.scene.objects:
|
490 |
+
if obj.type == "MESH":
|
491 |
+
meshes.append(obj)
|
492 |
+
|
493 |
+
return meshes
|
494 |
+
|
495 |
+
|
496 |
+
def merge_mesh(folder_path, export = None, save_join = True):
|
497 |
+
# output_path = os.path.join(folder_path, "rigging_norig.json")
|
498 |
+
# if os.path.exists(output_path):
|
499 |
+
# print("Already processed folder:", folder_path)
|
500 |
+
# return
|
501 |
+
bpy.ops.wm.read_homefile(use_empty=True)
|
502 |
+
try:
|
503 |
+
reload_json(folder_path)
|
504 |
+
except:
|
505 |
+
print("Error in reloading json file")
|
506 |
+
# remove the folder
|
507 |
+
os.system(f"rm -r {folder_path}")
|
508 |
+
return None, None
|
509 |
+
|
510 |
+
bpy.ops.object.select_all(action='DESELECT')
|
511 |
+
if export:
|
512 |
+
bpy.ops.wm.save_as_mainfile(filepath='reload_' + export)
|
513 |
+
|
514 |
+
meshes = []
|
515 |
+
for obj in bpy.context.scene.objects:
|
516 |
+
if obj.type == "MESH":
|
517 |
+
bpy.context.view_layer.objects.active = obj
|
518 |
+
obj.select_set(True)
|
519 |
+
meshes.append(obj)
|
520 |
+
print("meshes length", len(meshes))
|
521 |
+
|
522 |
+
bpy.ops.object.join()
|
523 |
+
if export:
|
524 |
+
bpy.ops.wm.save_as_mainfile(filepath='join_' + export)
|
525 |
+
|
526 |
+
meshes = []
|
527 |
+
for obj in bpy.context.scene.objects:
|
528 |
+
if obj.type == "MESH":
|
529 |
+
meshes.append(obj)
|
530 |
+
if len(meshes) != 1:
|
531 |
+
bpy.ops.wm.save_as_mainfile(filepath='join_f.blend')
|
532 |
+
assert len(meshes) == 1
|
533 |
+
# remesh(meshes[0])
|
534 |
+
|
535 |
+
|
536 |
+
if save_join:
|
537 |
+
obj_path = os.path.join(folder_path, "object.obj")
|
538 |
+
bpy.ops.wm.obj_export(filepath=obj_path, export_materials=False, export_uv=False, export_triangulated_mesh=True)
|
539 |
+
# mesh = trimesh.load(glb_file_path)
|
540 |
+
# mesh.export(obj_path, file_type='obj')
|
541 |
+
|
542 |
+
|
543 |
+
# save to json file
|
544 |
+
total_armature_count = 0
|
545 |
+
armature_obj = []
|
546 |
+
mesh_obj = []
|
547 |
+
for obj in bpy.context.scene.objects:
|
548 |
+
if obj.type == "ARMATURE":
|
549 |
+
total_armature_count += 1
|
550 |
+
armature_obj.append(obj)
|
551 |
+
if obj.type == "MESH":
|
552 |
+
mesh_obj.append(obj)
|
553 |
+
if total_armature_count == 0:
|
554 |
+
print("No rigging information for the file:", folder_path+"\n")
|
555 |
+
return None, None
|
556 |
+
|
557 |
+
|
558 |
+
######### delete bones that are not in the vertex group
|
559 |
+
vertex_group_name = [group.name for group in mesh_obj[0].vertex_groups]
|
560 |
+
bpy.context.view_layer.objects.active = armature_obj[0]
|
561 |
+
bpy.ops.object.mode_set(mode='EDIT')
|
562 |
+
edit_bones = armature_obj[0].data.edit_bones
|
563 |
+
bone_delete = set([bone.name for bone in edit_bones]) - set(vertex_group_name)
|
564 |
+
print(f"Deleting {len(bone_delete)} bones")
|
565 |
+
for bone in bone_delete:
|
566 |
+
# if the bone is root, then do not delete it
|
567 |
+
if edit_bones[bone].parent == None:
|
568 |
+
# return len([1 for child in edit_bones[bone].children if child.name in bone_delete])
|
569 |
+
num_children = len(edit_bones[bone].children)
|
570 |
+
if num_children <= 1:
|
571 |
+
edit_bones.remove(edit_bones[bone])
|
572 |
+
continue
|
573 |
+
if num_children > 1:
|
574 |
+
center = mathutils.Vector((0, 0, 0))
|
575 |
+
for child in edit_bones[bone].children:
|
576 |
+
center += child.head
|
577 |
+
center /= num_children
|
578 |
+
min_dist = 1e9
|
579 |
+
for child in edit_bones[bone].children:
|
580 |
+
dist = (child.head - center).length
|
581 |
+
if dist < min_dist:
|
582 |
+
min_dist = dist
|
583 |
+
min_child = child
|
584 |
+
for child in edit_bones[bone].children:
|
585 |
+
if child != min_child:
|
586 |
+
child.parent = min_child
|
587 |
+
edit_bones.remove(edit_bones[bone])
|
588 |
+
continue
|
589 |
+
continue
|
590 |
+
# assign bone's children to bone's parent
|
591 |
+
bone_obj = edit_bones[bone]
|
592 |
+
for child in bone_obj.children:
|
593 |
+
child.parent = bone_obj.parent
|
594 |
+
|
595 |
+
edit_bones.remove(edit_bones[bone])
|
596 |
+
bpy.ops.object.mode_set(mode='OBJECT')
|
597 |
+
|
598 |
+
if export:
|
599 |
+
bpy.ops.wm.save_as_mainfile(filepath='delete_' + export)
|
600 |
+
|
601 |
+
mesh_obj = []
|
602 |
+
armature_obj = []
|
603 |
+
for obj in bpy.context.scene.objects:
|
604 |
+
if obj.type == "MESH":
|
605 |
+
mesh_obj.append(obj)
|
606 |
+
if obj.type == "ARMATURE":
|
607 |
+
armature_obj.append(obj)
|
608 |
+
assert len(mesh_obj) == 1
|
609 |
+
assert len(armature_obj) == 1
|
610 |
+
|
611 |
+
return mesh_obj, armature_obj
|
612 |
+
|
613 |
+
|
614 |
+
def process(file_path, obj_path=None, stamp=None, tex=False):
|
615 |
+
# check if obj_path exists
|
616 |
+
# if os.path.exists(obj_path + '/object.obj'):
|
617 |
+
# print('object.obj exists')
|
618 |
+
# return True
|
619 |
+
reset_scene()
|
620 |
+
load_object(file_path)
|
621 |
+
# bpy.ops.import_scene.gltf(filepath=glb_file_path)
|
622 |
+
|
623 |
+
# delete hierarchy collections['glTF_not_exported']
|
624 |
+
if 'glTF_not_exported' in bpy.data.collections:
|
625 |
+
print('DELETE glTF_not_exported')
|
626 |
+
bpy.data.collections.remove(bpy.data.collections['glTF_not_exported'])
|
627 |
+
|
628 |
+
if stamp is not None:
|
629 |
+
# Set the current frame to the stamp value
|
630 |
+
bpy.context.scene.frame_set(stamp)
|
631 |
+
print(f'Set the current frame to {stamp}')
|
632 |
+
|
633 |
+
# Ensure all objects are updated to this frame
|
634 |
+
bpy.context.view_layer.update()
|
635 |
+
|
636 |
+
mesh_obj = []
|
637 |
+
armature_obj = []
|
638 |
+
for obj in bpy.context.scene.objects:
|
639 |
+
if obj.type == "ARMATURE":
|
640 |
+
# if len(armature_obj) > 0:
|
641 |
+
# print(file_path, 'has more than 1 armature')
|
642 |
+
# return -2
|
643 |
+
armature_obj.append(obj)
|
644 |
+
# obj.show_in_front = True
|
645 |
+
armature_obj[-1].data.pose_position = 'POSE'
|
646 |
+
if obj.type == "MESH":
|
647 |
+
mesh_obj.append(obj)
|
648 |
+
# if obj.data.shape_keys is not None:
|
649 |
+
# return False
|
650 |
+
|
651 |
+
# mesh_obj = delete(mesh_obj)
|
652 |
+
# if len(mesh_obj) == 0:
|
653 |
+
# # print('zyhsb -1', file_path, obj_path)
|
654 |
+
# return -1
|
655 |
+
# return check(mesh_obj, armature_obj)
|
656 |
+
|
657 |
+
|
658 |
+
# total_vertices = np.array([len(obj.data.vertices) for obj in mesh_obj]).sum()
|
659 |
+
# if total_vertices < 1000: return
|
660 |
+
# if total_vertices > 10000: remesh(mesh_obj)
|
661 |
+
|
662 |
+
|
663 |
+
# bpy.ops.object.select_all(action='DESELECT')
|
664 |
+
# armature_obj.select_set(True)
|
665 |
+
# execute(bpy.context)
|
666 |
+
|
667 |
+
|
668 |
+
# normalize(mesh_obj)
|
669 |
+
|
670 |
+
|
671 |
+
mesh_obj = delete(mesh_obj)
|
672 |
+
if len(mesh_obj) == 0:
|
673 |
+
# print('zyhsb -1', file_path, obj_path)
|
674 |
+
return -1
|
675 |
+
|
676 |
+
|
677 |
+
save_json(obj_path, mesh_obj, armature_obj, arm_name=True)
|
678 |
+
|
679 |
+
|
680 |
+
if not tex:
|
681 |
+
save_mesh(obj_path + '/object.obj')
|
682 |
+
else:
|
683 |
+
save_mesh(obj_path + '/object.obj', mtl=True, obj_path=obj_path)
|
684 |
+
|
685 |
+
|
686 |
+
mesh_obj, armature_obj = merge_mesh(obj_path)
|
687 |
+
if mesh_obj is None or armature_obj is None:
|
688 |
+
# print('zyhsb -2', file_path, obj_path)
|
689 |
+
return -2
|
690 |
+
|
691 |
+
|
692 |
+
try:
|
693 |
+
normalize(mesh_obj)
|
694 |
+
except:
|
695 |
+
os.system(f"rm -r {obj_path}")
|
696 |
+
# print('zyhsb -3', file_path, obj_path)
|
697 |
+
return -3
|
698 |
+
|
699 |
+
|
700 |
+
save_json(obj_path, mesh_obj, armature_obj)
|
701 |
+
|
702 |
+
if not tex:
|
703 |
+
save_mesh(obj_path + '/object.obj')
|
704 |
+
else:
|
705 |
+
save_mesh(obj_path + '/object.obj', mtl=True, obj_path=obj_path)
|
706 |
+
|
707 |
+
|
708 |
+
return 1
|
709 |
+
|
710 |
+
|
711 |
+
if __name__ == '__main__':
|
712 |
+
|
713 |
+
parser = argparse.ArgumentParser()
|
714 |
+
parser.add_argument(
|
715 |
+
"--object_path",
|
716 |
+
type=str,
|
717 |
+
required=True,
|
718 |
+
help="Path to the object file",
|
719 |
+
)
|
720 |
+
parser.add_argument(
|
721 |
+
"--output_dir",
|
722 |
+
type=str,
|
723 |
+
required=True,
|
724 |
+
help="Path to the directory where the rendered images and metadata will be saved.",
|
725 |
+
)
|
726 |
+
parser.add_argument(
|
727 |
+
"--stamp",
|
728 |
+
type=int,
|
729 |
+
required=False,
|
730 |
+
help="Stamp to be used for the rendering.",
|
731 |
+
)
|
732 |
+
parser.add_argument(
|
733 |
+
"--tex",
|
734 |
+
type=bool,
|
735 |
+
required=False,
|
736 |
+
help="Save the texture.",
|
737 |
+
)
|
738 |
+
argv = sys.argv[sys.argv.index("--") + 1 :]
|
739 |
+
args = parser.parse_args(argv)
|
740 |
+
|
741 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
742 |
+
stamp = args.stamp if args.stamp else None
|
743 |
+
print(f'Stamp: {stamp}')
|
744 |
+
result = process(args.object_path, obj_path=args.output_dir, stamp=stamp, tex=args.tex)
|
745 |
+
# import numpy as np
|
746 |
+
# os.makedirs(args.output_dir, exist_ok=True) # the directory may be removed
|
747 |
+
# np.save(args.output_dir + '/result.npy', np.array(result))
|
Anymate/checkpoints/.gitkeep
ADDED
File without changes
|
Anymate/configs/.gitkeep
ADDED
File without changes
|
Anymate/configs/conn.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
args:
|
2 |
+
aggr: max
|
3 |
+
checkpoint: Anymate/checkpoints
|
4 |
+
device: cuda
|
5 |
+
epochs: 200
|
6 |
+
finetune: true
|
7 |
+
gamma: 0.2
|
8 |
+
input_normal: false
|
9 |
+
logdir: Anymate/logs
|
10 |
+
loss: ce
|
11 |
+
mode: conn
|
12 |
+
resume: ''
|
13 |
+
root: Anymate/data
|
14 |
+
schedule: []
|
15 |
+
start_epoch: 0
|
16 |
+
test_batch: 1
|
17 |
+
testset: Anymate_test
|
18 |
+
train_batch: 16
|
19 |
+
trainset: Anymate_train
|
20 |
+
test_freq: 10
|
21 |
+
|
22 |
+
optimizer:
|
23 |
+
weight_decay: 1.0e-05
|
24 |
+
lr: 0.0001
|
25 |
+
|
26 |
+
model:
|
27 |
+
decoder: attendjoints_con_combine
|
28 |
+
encoder: bert
|
29 |
+
config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
|
30 |
+
ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
|
31 |
+
load_encoder: ''
|
32 |
+
num_joints: 96
|
33 |
+
out_channels: 3
|
34 |
+
width: 768
|
35 |
+
heads: 12
|
36 |
+
init_scale: 0.25
|
37 |
+
flash: False
|
38 |
+
use_checkpoint: False
|
39 |
+
qkv_bias: False
|
40 |
+
separate: False
|
Anymate/configs/conn_token.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
args:
|
2 |
+
aggr: max
|
3 |
+
checkpoint: Anymate/checkpoints
|
4 |
+
device: cuda
|
5 |
+
epochs: 200
|
6 |
+
finetune: true
|
7 |
+
gamma: 0.2
|
8 |
+
input_normal: false
|
9 |
+
logdir: Anymate/logs
|
10 |
+
loss: ce
|
11 |
+
mode: conn
|
12 |
+
resume: ''
|
13 |
+
root: Anymate/data
|
14 |
+
schedule: []
|
15 |
+
start_epoch: 0
|
16 |
+
test_batch: 1
|
17 |
+
testset: Anymate_test
|
18 |
+
train_batch: 16
|
19 |
+
trainset: Anymate_train
|
20 |
+
test_freq: 10
|
21 |
+
|
22 |
+
optimizer:
|
23 |
+
weight_decay: 1.0e-05
|
24 |
+
lr: 0.0001
|
25 |
+
|
26 |
+
model:
|
27 |
+
decoder: attendjoints_con_combine
|
28 |
+
encoder: bert
|
29 |
+
config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
|
30 |
+
ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
|
31 |
+
load_encoder: ''
|
32 |
+
num_joints: 96
|
33 |
+
out_channels: 3
|
34 |
+
width: 768
|
35 |
+
heads: 12
|
36 |
+
init_scale: 0.25
|
37 |
+
flash: False
|
38 |
+
use_checkpoint: False
|
39 |
+
qkv_bias: False
|
40 |
+
separate: False
|
Anymate/configs/diffusion.yaml
ADDED
@@ -0,0 +1,49 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
args:
|
2 |
+
aggr: max
|
3 |
+
checkpoint: Anymate/checkpoints
|
4 |
+
device: cuda
|
5 |
+
epochs: 4000
|
6 |
+
finetune: true
|
7 |
+
gamma: 0.2
|
8 |
+
input_normal: false
|
9 |
+
logdir: Anymate/logs
|
10 |
+
loss: chamfer
|
11 |
+
mode: diffusion
|
12 |
+
resume: ''
|
13 |
+
root: Anymate/data
|
14 |
+
schedule: []
|
15 |
+
start_epoch: 0
|
16 |
+
test_batch: 1
|
17 |
+
testset: Anymate_test
|
18 |
+
train_batch: 16
|
19 |
+
trainset: Anymate_train
|
20 |
+
test_freq: 50
|
21 |
+
num_train_step: 100
|
22 |
+
num_training_points: 128
|
23 |
+
seed: 42
|
24 |
+
|
25 |
+
optimizer:
|
26 |
+
weight_decay: 1.0e-05
|
27 |
+
lr: 0.0001
|
28 |
+
|
29 |
+
model:
|
30 |
+
encoder: transformer
|
31 |
+
decoder: Cross_Attention_Diffusion
|
32 |
+
config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
|
33 |
+
ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
|
34 |
+
input_channels: 3
|
35 |
+
output_channels: 3
|
36 |
+
num_z: 16
|
37 |
+
num_x: 128
|
38 |
+
z_dim: 768
|
39 |
+
x_dim: 512
|
40 |
+
num_blocks: 4
|
41 |
+
num_compute_layers: 4
|
42 |
+
num_heads: 8
|
43 |
+
mlp_ratio: 4.0
|
44 |
+
qkv_bias: true
|
45 |
+
drop: 0.0
|
46 |
+
attn_drop: 0.0
|
47 |
+
drop_path: 0.0
|
48 |
+
num_latents: 16
|
49 |
+
use_projection: true
|
Anymate/configs/diffusion_concat.yaml
ADDED
@@ -0,0 +1,46 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
args:
|
2 |
+
aggr: max
|
3 |
+
checkpoint: Anymate/checkpoints
|
4 |
+
device: cuda
|
5 |
+
epochs: 4000
|
6 |
+
finetune: true
|
7 |
+
gamma: 0.2
|
8 |
+
input_normal: false
|
9 |
+
logdir: Anymate/logs
|
10 |
+
loss: chamfer
|
11 |
+
mode: diffusion
|
12 |
+
resume: ''
|
13 |
+
root: Anymate/data
|
14 |
+
schedule: []
|
15 |
+
start_epoch: 0
|
16 |
+
test_batch: 1
|
17 |
+
testset: Anymate_test
|
18 |
+
train_batch: 16
|
19 |
+
trainset: Anymate_train
|
20 |
+
test_freq: 1000
|
21 |
+
num_train_step: 100
|
22 |
+
num_training_points: 128
|
23 |
+
seed: 42
|
24 |
+
|
25 |
+
optimizer:
|
26 |
+
weight_decay: 1.0e-05
|
27 |
+
lr: 0.0001
|
28 |
+
|
29 |
+
model:
|
30 |
+
encoder: bert
|
31 |
+
decoder: Pointe_Diffusion
|
32 |
+
config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
|
33 |
+
ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
|
34 |
+
input_channels: 3
|
35 |
+
output_channels: 3
|
36 |
+
n_ctx: 128
|
37 |
+
width: 768
|
38 |
+
layers: 12
|
39 |
+
heads: 8
|
40 |
+
init_scale: 0.25
|
41 |
+
time_token_cond: true
|
42 |
+
cond_drop_prob: 0.1
|
43 |
+
use_projection: true
|
44 |
+
|
45 |
+
|
46 |
+
|
Anymate/configs/diffusion_cross.yaml
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
args:
|
2 |
+
aggr: max
|
3 |
+
checkpoint: Anymate/checkpoints
|
4 |
+
device: cuda
|
5 |
+
epochs: 4000
|
6 |
+
finetune: true
|
7 |
+
gamma: 0.2
|
8 |
+
input_normal: false
|
9 |
+
logdir: Anymate/logs
|
10 |
+
loss: chamfer
|
11 |
+
mode: diffusion
|
12 |
+
resume: ''
|
13 |
+
root: Anymate/data
|
14 |
+
schedule: []
|
15 |
+
start_epoch: 0
|
16 |
+
test_batch: 1
|
17 |
+
testset: Anymate_test
|
18 |
+
train_batch: 32
|
19 |
+
trainset: Anymate_train
|
20 |
+
test_freq: 1000
|
21 |
+
num_train_step: 100
|
22 |
+
num_training_points: 128
|
23 |
+
seed: 42
|
24 |
+
|
25 |
+
optimizer:
|
26 |
+
weight_decay: 1.0e-05
|
27 |
+
lr: 0.0001
|
28 |
+
|
29 |
+
model:
|
30 |
+
encoder: miche
|
31 |
+
decoder: Cross_Attention_Diffusion
|
32 |
+
config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
|
33 |
+
ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
|
34 |
+
input_channels: 3
|
35 |
+
output_channels: 3
|
36 |
+
num_z: 16
|
37 |
+
num_x: 128
|
38 |
+
z_dim: 768
|
39 |
+
x_dim: 512
|
40 |
+
num_blocks: 4
|
41 |
+
num_compute_layers: 4
|
42 |
+
num_heads: 8
|
43 |
+
mlp_ratio: 4.0
|
44 |
+
qkv_bias: true
|
45 |
+
drop: 0.0
|
46 |
+
attn_drop: 0.0
|
47 |
+
drop_path: 0.0
|
48 |
+
num_latents: 16
|
49 |
+
use_projection: true
|
50 |
+
|
51 |
+
|
Anymate/configs/joints.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
args:
|
2 |
+
aggr: max
|
3 |
+
checkpoint: Anymate/checkpoints
|
4 |
+
device: cuda
|
5 |
+
epochs: 200
|
6 |
+
finetune: true
|
7 |
+
gamma: 0.2
|
8 |
+
input_normal: false
|
9 |
+
logdir: Anymate/logs
|
10 |
+
loss: chamfer
|
11 |
+
mode: joints
|
12 |
+
resume: ''
|
13 |
+
root: Anymate/data
|
14 |
+
schedule: []
|
15 |
+
start_epoch: 0
|
16 |
+
test_batch: 1
|
17 |
+
testset: Anymate_test
|
18 |
+
train_batch: 16
|
19 |
+
trainset: Anymate_train
|
20 |
+
test_freq: 10
|
21 |
+
|
22 |
+
optimizer:
|
23 |
+
weight_decay: 1.0e-05
|
24 |
+
lr: 0.0001
|
25 |
+
|
26 |
+
model:
|
27 |
+
decoder: transformer_latent
|
28 |
+
encoder: bert
|
29 |
+
config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
|
30 |
+
ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
|
31 |
+
load_encoder: ''
|
32 |
+
num_joints: 96
|
33 |
+
out_channels: 3
|
34 |
+
width: 768
|
35 |
+
heads: 12
|
36 |
+
init_scale: 0.25
|
37 |
+
flash: False
|
38 |
+
use_checkpoint: False
|
39 |
+
qkv_bias: False
|
40 |
+
separate: False
|
Anymate/configs/joints_implicit.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
args:
|
2 |
+
aggr: max
|
3 |
+
checkpoint: Anymate/checkpoints
|
4 |
+
device: cuda
|
5 |
+
epochs: 200
|
6 |
+
finetune: true
|
7 |
+
gamma: 0.2
|
8 |
+
input_normal: false
|
9 |
+
logdir: Anymate/logs
|
10 |
+
loss: chamfer
|
11 |
+
mode: joints
|
12 |
+
resume: ''
|
13 |
+
root: Anymate/data
|
14 |
+
schedule: []
|
15 |
+
start_epoch: 0
|
16 |
+
test_batch: 1
|
17 |
+
testset: Anymate_test
|
18 |
+
train_batch: 8
|
19 |
+
trainset: Anymate_train
|
20 |
+
test_freq: 10
|
21 |
+
|
22 |
+
optimizer:
|
23 |
+
weight_decay: 1.0e-05
|
24 |
+
lr: 0.0001
|
25 |
+
|
26 |
+
model:
|
27 |
+
decoder: implicit_transformer
|
28 |
+
encoder: bert
|
29 |
+
config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
|
30 |
+
ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
|
31 |
+
load_encoder: ''
|
32 |
+
num_joints: 96
|
33 |
+
out_channels: 3
|
34 |
+
width: 768
|
35 |
+
heads: 12
|
36 |
+
init_scale: 0.25
|
37 |
+
flash: False
|
38 |
+
use_checkpoint: False
|
39 |
+
qkv_bias: False
|
40 |
+
separate: False
|
Anymate/configs/joints_triplane.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
args:
|
2 |
+
aggr: max
|
3 |
+
checkpoint: Anymate/checkpoints
|
4 |
+
device: cuda
|
5 |
+
epochs: 200
|
6 |
+
finetune: true
|
7 |
+
gamma: 0.2
|
8 |
+
input_normal: false
|
9 |
+
logdir: Anymate/logs
|
10 |
+
loss: chamfer
|
11 |
+
mode: joints
|
12 |
+
resume: ''
|
13 |
+
root: Anymate/data
|
14 |
+
schedule: []
|
15 |
+
start_epoch: 0
|
16 |
+
test_batch: 1
|
17 |
+
testset: Anymate_test
|
18 |
+
train_batch: 16
|
19 |
+
trainset: Anymate_train
|
20 |
+
test_freq: 10
|
21 |
+
|
22 |
+
optimizer:
|
23 |
+
weight_decay: 1.0e-05
|
24 |
+
lr: 0.0001
|
25 |
+
|
26 |
+
model:
|
27 |
+
decoder: triplane
|
28 |
+
encoder: bert
|
29 |
+
config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
|
30 |
+
ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
|
31 |
+
load_encoder: ''
|
32 |
+
num_joints: 96
|
33 |
+
out_channels: 3
|
34 |
+
width: 768
|
35 |
+
heads: 12
|
36 |
+
init_scale: 0.25
|
37 |
+
flash: False
|
38 |
+
use_checkpoint: False
|
39 |
+
qkv_bias: False
|
40 |
+
separate: False
|
Anymate/configs/skin.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
args:
|
2 |
+
aggr: max
|
3 |
+
checkpoint: Anymate/checkpoints
|
4 |
+
device: cuda
|
5 |
+
epochs: 200
|
6 |
+
finetune: true
|
7 |
+
gamma: 0.2
|
8 |
+
input_normal: false
|
9 |
+
logdir: Anymate/logs
|
10 |
+
loss: cos_clamp
|
11 |
+
mode: skin
|
12 |
+
resume: ''
|
13 |
+
root: Anymate/data
|
14 |
+
schedule: []
|
15 |
+
start_epoch: 0
|
16 |
+
test_batch: 1
|
17 |
+
testset: Anymate_test
|
18 |
+
train_batch: 16
|
19 |
+
trainset: Anymate_train
|
20 |
+
test_freq: 10
|
21 |
+
|
22 |
+
optimizer:
|
23 |
+
weight_decay: 1.0e-05
|
24 |
+
lr: 0.0001
|
25 |
+
|
26 |
+
model:
|
27 |
+
decoder: attendjoints_combine
|
28 |
+
encoder: bert
|
29 |
+
config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
|
30 |
+
ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
|
31 |
+
load_encoder: ''
|
32 |
+
num_joints: 96
|
33 |
+
out_channels: 3
|
34 |
+
width: 768
|
35 |
+
heads: 12
|
36 |
+
init_scale: 0.25
|
37 |
+
flash: False
|
38 |
+
use_checkpoint: False
|
39 |
+
qkv_bias: False
|
40 |
+
separate: False
|
Anymate/configs/skin_multi.yaml
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
args:
|
2 |
+
aggr: max
|
3 |
+
checkpoint: Anymate/checkpoints
|
4 |
+
device: cuda
|
5 |
+
epochs: 200
|
6 |
+
finetune: true
|
7 |
+
gamma: 0.2
|
8 |
+
input_normal: false
|
9 |
+
logdir: Anymate/logs
|
10 |
+
loss: cos_clamp
|
11 |
+
mode: skin
|
12 |
+
resume: ''
|
13 |
+
root: Anymate/data
|
14 |
+
schedule: []
|
15 |
+
start_epoch: 0
|
16 |
+
test_batch: 1
|
17 |
+
testset: Anymate_test
|
18 |
+
train_batch: 4
|
19 |
+
trainset: Anymate_train
|
20 |
+
test_freq: 10
|
21 |
+
|
22 |
+
optimizer:
|
23 |
+
weight_decay: 1.0e-05
|
24 |
+
lr: 0.0001
|
25 |
+
|
26 |
+
model:
|
27 |
+
decoder: attendjoints_multi
|
28 |
+
encoder: bert
|
29 |
+
config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
|
30 |
+
ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
|
31 |
+
load_encoder: ''
|
32 |
+
num_joints: 96
|
33 |
+
out_channels: 3
|
34 |
+
width: 768
|
35 |
+
heads: 12
|
36 |
+
init_scale: 0.25
|
37 |
+
flash: False
|
38 |
+
use_checkpoint: False
|
39 |
+
qkv_bias: False
|
40 |
+
separate: False
|
Anymate/dataset.py
ADDED
@@ -0,0 +1,62 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from torch.utils.data import Dataset
|
3 |
+
import os
|
4 |
+
import numpy as np
|
5 |
+
from Anymate.utils.dataset_utils import create_mask, index_to_sparse, index_to_sparse_con
|
6 |
+
|
7 |
+
def my_collate(batch):
|
8 |
+
# print(len(batch))
|
9 |
+
data = {}
|
10 |
+
for key in batch[0]:
|
11 |
+
if key=='vox' or key=='name' or key=='joints_num' or key=='skins_index' or key=='skins_weight' or key=='parent_index' or key=='conns' or key=='joints' or key=='bones' or key=='mesh_skins_index' or key=='mesh_skins_weight' or key=='mesh_pc' or key=='mesh_face':
|
12 |
+
data[key] = [sample[key] for sample in batch]
|
13 |
+
elif key=='pc':
|
14 |
+
data['points_cloud'] = torch.stack([sample['pc'] for sample in batch])
|
15 |
+
elif key=='skins':
|
16 |
+
continue
|
17 |
+
elif key=='bones_num':
|
18 |
+
data[key] = torch.tensor([sample['bones_num'] for sample in batch])
|
19 |
+
else:
|
20 |
+
data[key] = torch.stack([sample[key] for sample in batch])
|
21 |
+
|
22 |
+
if 'skins_index' in batch[0]:
|
23 |
+
max_joints = max(data['joints_num'])
|
24 |
+
max_bones = max(data['bones_num'])
|
25 |
+
# max_joints = 64
|
26 |
+
skin_list = [index_to_sparse(data['skins_index'][i].unsqueeze(0), data['skins_weight'][i].unsqueeze(0), [1, 8192, max_bones])[0] for i in range(len(data['skins_index']))]
|
27 |
+
data['skins'] = torch.stack(skin_list,dim=0)
|
28 |
+
data['joints_mask'] = torch.stack([create_mask(sample['joints_num'],max_len=max_joints) for sample in batch])
|
29 |
+
data['bones_mask'] = torch.stack([create_mask(sample['bones_num'],max_len=max_bones) for sample in batch])
|
30 |
+
|
31 |
+
if 'conns' in batch[0]:
|
32 |
+
max_joints = max(data['joints_num'])
|
33 |
+
conn_matrix = torch.zeros(len(data['conns']), 96, max_joints)
|
34 |
+
for i in range(len(data['conns'])):
|
35 |
+
for j in range(data['joints_num'][i]):
|
36 |
+
conn_matrix[i, j, data['conns'][i][j].long()] = 1
|
37 |
+
data['conns'] = conn_matrix
|
38 |
+
if 'joints' in batch[0]:
|
39 |
+
padded_joints_matrix = torch.ones(len(data['name']), 96, 3) * (-3)
|
40 |
+
for i in range(len(data['name'])):
|
41 |
+
padded_joints_matrix[i, :data['joints_num'][i], :] = data['joints'][i]
|
42 |
+
data['joints'] = padded_joints_matrix
|
43 |
+
if 'bones' in batch[0]:
|
44 |
+
padded_bones_matrix = torch.ones(len(data['name']), 64, 6) * (-3)
|
45 |
+
for i in range(len(data['name'])):
|
46 |
+
padded_bones_matrix[i, :data['bones_num'][i], :] = data['bones'][i]
|
47 |
+
data['bones'] = padded_bones_matrix
|
48 |
+
return data
|
49 |
+
|
50 |
+
class AnymateDataset(Dataset):
|
51 |
+
def __init__(self, name='Anymate_test', root='Anymate/data'):
|
52 |
+
|
53 |
+
if os.path.exists(os.path.join(root, name) + '.pt'):
|
54 |
+
self.data_list = torch.load(os.path.join(root, name) + '.pt')
|
55 |
+
else:
|
56 |
+
raise ValueError('Dataset not found at path: {}'.format(os.path.join(root, name) + '.pt'))
|
57 |
+
|
58 |
+
def __len__(self):
|
59 |
+
return len(self.data_list)
|
60 |
+
|
61 |
+
def __getitem__(self, idx):
|
62 |
+
return self.data_list[idx]
|
Anymate/get_checkpoints.sh
ADDED
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cd Anymate/checkpoints
|
2 |
+
mkdir joint
|
3 |
+
cd joint
|
4 |
+
|
5 |
+
echo "Downloading joint checkpoints..."
|
6 |
+
wget "https://huggingface.co/yfdeng/Anymate/resolve/main/checkpoints/joint/bert-transformer_latent-train-8gpu-finetune.pth.tar?download=true" -O bert-transformer_latent-train-8gpu-finetune.pth.tar
|
7 |
+
|
8 |
+
cd ..
|
9 |
+
mkdir conn
|
10 |
+
cd conn
|
11 |
+
|
12 |
+
echo "Downloading conn checkpoints..."
|
13 |
+
wget "https://huggingface.co/yfdeng/Anymate/resolve/main/checkpoints/conn/bert-attendjoints_con_combine-train-8gpu-finetune.pth.tar?download=true" -O bert-attendjoints_con_combine-train-8gpu-finetune.pth.tar
|
14 |
+
|
15 |
+
cd ..
|
16 |
+
mkdir skin
|
17 |
+
cd skin
|
18 |
+
|
19 |
+
echo "Downloading skin checkpoints..."
|
20 |
+
wget "https://huggingface.co/yfdeng/Anymate/resolve/main/checkpoints/skin/bert-attendjoints_combine-train-8gpu-finetune.pth.tar?download=true" -O bert-attendjoints_combine-train-8gpu-finetune.pth.tar
|
21 |
+
|
22 |
+
echo "Finished downloading checkpoints!"
|
Anymate/get_datasets.sh
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
cd Anymate/data
|
2 |
+
wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_test.pt?download=true" -O Anymate_test.pt
|
3 |
+
wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_0.pt?download=true" -O Anymate_train_0.pt
|
4 |
+
wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_1.pt?download=true" -O Anymate_train_1.pt
|
5 |
+
wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_2.pt?download=true" -O Anymate_train_2.pt
|
6 |
+
wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_3.pt?download=true" -O Anymate_train_3.pt
|
7 |
+
wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_4.pt?download=true" -O Anymate_train_4.pt
|
8 |
+
wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_5.pt?download=true" -O Anymate_train_5.pt
|
9 |
+
wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_6.pt?download=true" -O Anymate_train_6.pt
|
10 |
+
wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_7.pt?download=true" -O Anymate_train_7.pt
|
11 |
+
|
12 |
+
echo "Finished downloading datasets!"
|
Anymate/model.py
ADDED
@@ -0,0 +1,360 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from ThirdParty.michelangelo.utils.misc import get_config_from_file, instantiate_from_config
|
4 |
+
# from ThirdParty.PointLLM.pointllm.model.pointllm import PointLLMLlamaForCausalLM
|
5 |
+
from ThirdParty.michelangelo.models.modules.distributions import DiagonalGaussianDistribution
|
6 |
+
from ThirdParty.michelangelo.models.modules.embedder import components_from_spherical_harmonics
|
7 |
+
from Anymate.utils.diffusion_encoder import TransformerEncoder
|
8 |
+
from Anymate.models.joint import TransformerDecoder, ImplicitTransformerDecoder, TriPlaneDecoder
|
9 |
+
from Anymate.models.conn import AttendjointsDecoder_con_combine, AttendjointsDecoder_con_token
|
10 |
+
from Anymate.models.skin import AttendjointsDecoder_combine, AttendjointsDecoder_multi
|
11 |
+
from Anymate.models.diffusion import Pointe_Diffusion, Cross_Attention_Diffusion
|
12 |
+
|
13 |
+
class Encoder(nn.Module):
|
14 |
+
def __init__(self,
|
15 |
+
only_embed = True,
|
16 |
+
config_path = './ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml',
|
17 |
+
ckpt_path = './ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt',
|
18 |
+
num_latents = 257,
|
19 |
+
device = 'cuda'):
|
20 |
+
|
21 |
+
super().__init__()
|
22 |
+
|
23 |
+
model_config = get_config_from_file(config_path)
|
24 |
+
if hasattr(model_config, "model"):
|
25 |
+
model_config = model_config.model
|
26 |
+
|
27 |
+
if ckpt_path is not None:
|
28 |
+
model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
|
29 |
+
else:
|
30 |
+
model = instantiate_from_config(model_config)
|
31 |
+
model.model.shape_model.encoder.num_latents = num_latents
|
32 |
+
model.model.shape_model.encoder.query = nn.Parameter(torch.randn((num_latents, 768), device=device, dtype=torch.float32) * 0.02)
|
33 |
+
|
34 |
+
self.shape_projection = model.model.shape_projection
|
35 |
+
self.encoder = model.model.shape_model.encoder
|
36 |
+
self.normal_embedder = components_from_spherical_harmonics
|
37 |
+
old_linear_proj = self.encoder.input_proj
|
38 |
+
self.encoder.input_proj = nn.Linear(old_linear_proj.in_features + 25, old_linear_proj.out_features)
|
39 |
+
self.encoder.input_proj.weight.data[:, :old_linear_proj.in_features] = old_linear_proj.weight.data[:, :old_linear_proj.in_features].clone()
|
40 |
+
self.encoder.input_proj.bias.data = old_linear_proj.bias.data.clone()
|
41 |
+
if not only_embed:
|
42 |
+
self.embed_dim = model.model.shape_model.embed_dim
|
43 |
+
self.pre_kl = model.model.shape_model.pre_kl
|
44 |
+
self.post_kl = model.model.shape_model.post_kl
|
45 |
+
self.transformer = model.model.shape_model.transformer
|
46 |
+
|
47 |
+
|
48 |
+
def encode_latents(self,
|
49 |
+
pc: torch.FloatTensor,
|
50 |
+
feats = None):
|
51 |
+
|
52 |
+
feats_embed = self.normal_embedder(feats)
|
53 |
+
feats = torch.cat([feats, feats_embed], dim=-1)
|
54 |
+
|
55 |
+
x, _ = self.encoder(pc, feats)
|
56 |
+
|
57 |
+
shape_embed = x[:, 0]
|
58 |
+
latents = x[:, 1:]
|
59 |
+
|
60 |
+
return shape_embed, latents
|
61 |
+
|
62 |
+
|
63 |
+
def encode_shape_embed(self, surface, return_latents: bool = False):
|
64 |
+
"""
|
65 |
+
|
66 |
+
Args:
|
67 |
+
surface (torch.FloatTensor): [bs, n, 3 + c]
|
68 |
+
return_latents (bool):
|
69 |
+
|
70 |
+
Returns:
|
71 |
+
x (torch.FloatTensor): [bs, projection_dim]
|
72 |
+
shape_latents (torch.FloatTensor): [bs, m, d]
|
73 |
+
"""
|
74 |
+
|
75 |
+
pc = surface[..., 0:3]
|
76 |
+
feats = surface[..., 3:]
|
77 |
+
|
78 |
+
shape_embed, shape_latents = self.encode_latents(pc, feats)
|
79 |
+
x = shape_embed @ self.shape_projection
|
80 |
+
|
81 |
+
if return_latents:
|
82 |
+
return x, shape_latents
|
83 |
+
else:
|
84 |
+
return x
|
85 |
+
|
86 |
+
|
87 |
+
def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
|
88 |
+
posterior = None
|
89 |
+
if self.embed_dim > 0:
|
90 |
+
moments = self.pre_kl(latents)
|
91 |
+
posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
|
92 |
+
|
93 |
+
if sample_posterior:
|
94 |
+
kl_embed = posterior.sample()
|
95 |
+
else:
|
96 |
+
kl_embed = posterior.mode()
|
97 |
+
else:
|
98 |
+
kl_embed = latents
|
99 |
+
|
100 |
+
return kl_embed, posterior
|
101 |
+
|
102 |
+
|
103 |
+
def decode(self, latents: torch.FloatTensor):
|
104 |
+
latents = self.post_kl(latents)
|
105 |
+
return self.transformer(latents)
|
106 |
+
|
107 |
+
|
108 |
+
class EncoderDecoder(nn.Module):
|
109 |
+
def __init__(self,
|
110 |
+
decoder = 'mlp',
|
111 |
+
encoder = 'miche',
|
112 |
+
config_path = './ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml',
|
113 |
+
ckpt_path = './ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt',
|
114 |
+
load_encoder = '',
|
115 |
+
num_joints = 96,
|
116 |
+
out_channels = 3,
|
117 |
+
width = 768,
|
118 |
+
device = 'cuda',
|
119 |
+
dtype = torch.float32,
|
120 |
+
heads = 12,
|
121 |
+
init_scale: float = 0.25,
|
122 |
+
flash = False,
|
123 |
+
use_checkpoint = False,
|
124 |
+
qkv_bias = False,
|
125 |
+
separate = False,
|
126 |
+
**kwargs):
|
127 |
+
|
128 |
+
super().__init__()
|
129 |
+
self.decoder_name = decoder
|
130 |
+
self.encoder_name = encoder
|
131 |
+
self.dtype = dtype
|
132 |
+
self.load_encoder = load_encoder
|
133 |
+
|
134 |
+
if decoder == 'transformer_latent':
|
135 |
+
self.only_embed = False
|
136 |
+
self.return_latents = True
|
137 |
+
self.decoder = TransformerDecoder(
|
138 |
+
num_latents = num_joints,
|
139 |
+
out_channels = out_channels,
|
140 |
+
width = width,
|
141 |
+
device = device,
|
142 |
+
dtype = dtype,
|
143 |
+
heads = heads,
|
144 |
+
init_scale = init_scale,
|
145 |
+
flash = flash,
|
146 |
+
use_checkpoint = use_checkpoint,
|
147 |
+
qkv_bias = qkv_bias
|
148 |
+
)
|
149 |
+
elif decoder == 'implicit_transformer':
|
150 |
+
self.only_embed = False
|
151 |
+
self.return_latents = True
|
152 |
+
self.decoder = ImplicitTransformerDecoder(
|
153 |
+
device = device,
|
154 |
+
dtype = dtype,
|
155 |
+
num_latents = 257,
|
156 |
+
out_channels = 1,
|
157 |
+
width = width,
|
158 |
+
heads = heads,
|
159 |
+
init_scale = init_scale,
|
160 |
+
flash = flash,
|
161 |
+
use_checkpoint = use_checkpoint,
|
162 |
+
qkv_bias = qkv_bias
|
163 |
+
)
|
164 |
+
elif decoder == 'triplane': #consider add these parameters to config
|
165 |
+
self.only_embed = True
|
166 |
+
self.return_latents = False
|
167 |
+
self.decoder = TriPlaneDecoder(
|
168 |
+
z_dim = 768,
|
169 |
+
c_dim = 0,
|
170 |
+
w_dim = 768,
|
171 |
+
mapping_kwargs = {'num_layers': 2},
|
172 |
+
synthesis_kwargs = {'num_fp16_res': 0, 'conv_clamp': None, 'fused_modconv_default': 'inference_only'}
|
173 |
+
)
|
174 |
+
|
175 |
+
elif decoder == 'Pointe_Diffusion':
|
176 |
+
self.only_embed = False
|
177 |
+
self.return_latents = True
|
178 |
+
self.decoder = Pointe_Diffusion(**kwargs)
|
179 |
+
|
180 |
+
elif decoder == 'Cross_Attention_Diffusion':
|
181 |
+
self.only_embed = False
|
182 |
+
self.return_latents = True
|
183 |
+
self.decoder = Cross_Attention_Diffusion(**kwargs)
|
184 |
+
|
185 |
+
elif decoder == 'attendjoints_combine':
|
186 |
+
self.only_embed = False
|
187 |
+
self.return_latents = True
|
188 |
+
self.decoder = AttendjointsDecoder_combine(
|
189 |
+
width = width,
|
190 |
+
device = device,
|
191 |
+
dtype = dtype,
|
192 |
+
heads = heads,
|
193 |
+
init_scale = init_scale,
|
194 |
+
flash = flash,
|
195 |
+
use_checkpoint = use_checkpoint,
|
196 |
+
separate = separate,
|
197 |
+
qkv_bias = qkv_bias
|
198 |
+
)
|
199 |
+
elif decoder == 'attendjoints_multi':
|
200 |
+
self.only_embed = False
|
201 |
+
self.return_latents = True
|
202 |
+
self.decoder = AttendjointsDecoder_multi(
|
203 |
+
width = width,
|
204 |
+
device = device,
|
205 |
+
dtype = dtype,
|
206 |
+
heads = heads,
|
207 |
+
init_scale = init_scale,
|
208 |
+
flash = flash,
|
209 |
+
use_checkpoint = use_checkpoint,
|
210 |
+
qkv_bias = qkv_bias,
|
211 |
+
separate=separate
|
212 |
+
)
|
213 |
+
elif decoder == 'attendjoints_con_combine':
|
214 |
+
self.only_embed = False
|
215 |
+
self.return_latents = True
|
216 |
+
self.decoder = AttendjointsDecoder_con_combine(
|
217 |
+
width = width,
|
218 |
+
device = device,
|
219 |
+
dtype = dtype,
|
220 |
+
heads = heads,
|
221 |
+
init_scale = init_scale,
|
222 |
+
flash = flash,
|
223 |
+
use_checkpoint = use_checkpoint,
|
224 |
+
qkv_bias = qkv_bias
|
225 |
+
)
|
226 |
+
elif decoder == 'attendjoints_con_token':
|
227 |
+
self.only_embed = False
|
228 |
+
self.return_latents = True
|
229 |
+
self.decoder = AttendjointsDecoder_con_token(
|
230 |
+
width = width,
|
231 |
+
device = device,
|
232 |
+
dtype = dtype,
|
233 |
+
heads = heads,
|
234 |
+
init_scale = init_scale,
|
235 |
+
flash = flash,
|
236 |
+
use_checkpoint = use_checkpoint,
|
237 |
+
qkv_bias = qkv_bias,
|
238 |
+
separate = separate
|
239 |
+
)
|
240 |
+
|
241 |
+
if encoder == 'miche':
|
242 |
+
if not self.load_encoder:
|
243 |
+
self.encoder = Encoder(only_embed=self.only_embed, config_path=config_path, ckpt_path=ckpt_path, device=device)
|
244 |
+
else:
|
245 |
+
self.encoder = Encoder(only_embed=self.only_embed, config_path=config_path, ckpt_path=None, device=device)
|
246 |
+
try:
|
247 |
+
print("=> loading encoder checkpoint '{}'".format(self.load_encoder))
|
248 |
+
checkpoint = torch.load(self.load_encoder, map_location='cpu')
|
249 |
+
state_dict = {k[8:]: v for k, v in checkpoint['state_dict'].items() if k.startswith('encoder')}
|
250 |
+
self.encoder.load_state_dict(state_dict)
|
251 |
+
print("=> loaded encoder checkpoint '{}'".format(self.load_encoder))
|
252 |
+
except:
|
253 |
+
print("=> no encoder checkpoint found at '{}'".format(self.load_encoder))
|
254 |
+
if self.load_encoder:
|
255 |
+
self.point_proj = nn.Sequential(
|
256 |
+
nn.Linear(768, 768, dtype=dtype),
|
257 |
+
nn.GELU(),
|
258 |
+
nn.Linear(768, 768, dtype=dtype),
|
259 |
+
)
|
260 |
+
|
261 |
+
if encoder == 'bert':
|
262 |
+
# model_name = 'RunsenXu/PointLLM_7B_v1.2'
|
263 |
+
# model = PointLLMLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=False, use_cache=True, torch_dtype=dtype)
|
264 |
+
# self.encoder = model.model.point_backbone.to(device)
|
265 |
+
# model = None
|
266 |
+
from ThirdParty.PointLLM.pointllm.model import PointTransformer
|
267 |
+
from ThirdParty.PointLLM.pointllm.utils import cfg_from_yaml_file
|
268 |
+
import os
|
269 |
+
# address of config file, in the same dir of this file
|
270 |
+
point_bert_config_name = "PointTransformer_8192point_2layer" # * default for v1.2, v1.1 uses PointTransformer_base_8192point.yaml
|
271 |
+
point_bert_config_addr = os.path.join("./ThirdParty/PointLLM/pointllm/model/pointbert/PointTransformer_8192point_2layer.yaml")
|
272 |
+
print(f"Loading PointBERT config from {point_bert_config_addr}.")
|
273 |
+
point_bert_config = cfg_from_yaml_file(point_bert_config_addr)
|
274 |
+
point_bert_config.model.point_dims = 6
|
275 |
+
use_max_pool = getattr(point_bert_config.model, "use_max_pool", False) # * default is false
|
276 |
+
|
277 |
+
self.encoder = PointTransformer(point_bert_config.model, use_max_pool=use_max_pool).to(device)
|
278 |
+
if self.return_latents:
|
279 |
+
self.point_proj = nn.Sequential(
|
280 |
+
nn.Linear(384, 512, dtype=dtype),
|
281 |
+
nn.GELU(),
|
282 |
+
nn.Linear(512, 512, dtype=dtype),
|
283 |
+
nn.GELU(),
|
284 |
+
nn.Linear(512, 768, dtype=dtype)
|
285 |
+
)
|
286 |
+
else:
|
287 |
+
self.point_proj = nn.ModuleList([
|
288 |
+
nn.Sequential(
|
289 |
+
nn.Linear(384, 512, dtype=dtype),
|
290 |
+
nn.GELU(),
|
291 |
+
nn.Linear(512, 512, dtype=dtype),
|
292 |
+
nn.GELU(),
|
293 |
+
nn.Linear(512, 768, dtype=dtype)
|
294 |
+
),
|
295 |
+
nn.Linear(513, 1, dtype=dtype)
|
296 |
+
])
|
297 |
+
if encoder == 'transformer':
|
298 |
+
self.points_cloud_embed = nn.Linear(
|
299 |
+
768, 768, device=device, dtype=dtype
|
300 |
+
)
|
301 |
+
self.encoder = TransformerEncoder(device=device,dtype=dtype, num_latents=kwargs['num_latents'])
|
302 |
+
|
303 |
+
|
304 |
+
|
305 |
+
def encode(self, data, device='cuda'):
|
306 |
+
assert self.encoder_name in ['miche', 'bert', 'transformer'], f'Encoder {self.encoder_name} not supported'
|
307 |
+
if self.encoder_name == 'miche':
|
308 |
+
surface = data['points_cloud'].to(self.dtype).to(device)
|
309 |
+
|
310 |
+
# encoding
|
311 |
+
shape_embed, shape_latents = self.encoder.encode_shape_embed(surface, return_latents=True) # ShapeAsLatentPerceiver.encode_latents(): encoder
|
312 |
+
|
313 |
+
if self.only_embed:
|
314 |
+
if self.return_latents:
|
315 |
+
if self.load_encoder:
|
316 |
+
return self.point_proj(torch.cat([shape_embed.unsqueeze(1), shape_latents], dim=1))
|
317 |
+
return torch.cat([shape_embed.unsqueeze(1), shape_latents], dim=1) # torch.Size([bs, 257, 768]
|
318 |
+
return shape_embed # shape_embed: torch.Size([bs, 768])
|
319 |
+
|
320 |
+
shape_zq, posterior = self.encoder.encode_kl_embed(shape_latents) # ShapeAsLatentPerceiver.encode_kl_embed(): pre_kl + DiagonalGaussianDistribution()
|
321 |
+
# shape_zq, posterior = self.encoder.encode_kl_embed(shape_latents, sample_posterior=False) # not sample
|
322 |
+
# pretrained weight has 0 +- 0.7 mean and 0.5 +- 0.5 std
|
323 |
+
# trained weight has 0 +- 1.8 mean and 0.1 +- 0.1 std
|
324 |
+
# generally okay
|
325 |
+
|
326 |
+
latents = self.encoder.decode(shape_zq) # ShapeAsLatentPerceiver.decode(): post_kl + transformer
|
327 |
+
|
328 |
+
if not self.return_latents:
|
329 |
+
latents = torch.cat([shape_latents, latents], dim=1) # torch.Size([bs, 512, 768])
|
330 |
+
|
331 |
+
if self.load_encoder:
|
332 |
+
return self.point_proj(torch.cat([shape_embed.unsqueeze(1), latents], dim=1))
|
333 |
+
return torch.cat([shape_embed.unsqueeze(1), latents], dim=1) # torch.Size([bs, 257 / 513, 768])
|
334 |
+
|
335 |
+
if self.encoder_name == 'bert':
|
336 |
+
points = data['points_cloud'].to(self.dtype).to(device)
|
337 |
+
points = points[:, :, :3] / 2
|
338 |
+
points = torch.cat([points, torch.zeros_like(points)], dim=-1)
|
339 |
+
points = self.encoder(points)
|
340 |
+
|
341 |
+
if self.return_latents:
|
342 |
+
points = self.point_proj(points)
|
343 |
+
else:
|
344 |
+
points = self.point_proj[0](points)
|
345 |
+
points = self.point_proj[1](points.permute(0, 2, 1)).squeeze(-1)
|
346 |
+
return points
|
347 |
+
|
348 |
+
if self.encoder_name == 'transformer':
|
349 |
+
points = data['points_cloud'].to(self.dtype).to(device)
|
350 |
+
cond = self.encoder.encode_pc(points)
|
351 |
+
cond = self.points_cloud_embed(cond)
|
352 |
+
return cond
|
353 |
+
|
354 |
+
def forward(self, data, device='cuda', downsample=False, **kwargs):
|
355 |
+
latents = self.encode(data, device)
|
356 |
+
# print('latents shape', latents.shape)
|
357 |
+
|
358 |
+
logits = self.decoder(latents, data, device=device, downsample=downsample,**kwargs)
|
359 |
+
|
360 |
+
return logits
|
Anymate/models/__init__.py
ADDED
File without changes
|
Anymate/models/conn.py
ADDED
@@ -0,0 +1,195 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from ThirdParty.michelangelo.models.modules.transformer_blocks import ResidualCrossAttentionBlock, ResidualAttentionBlock, Transformer
|
4 |
+
from ThirdParty.michelangelo.models.modules.embedder import FourierEmbedder, components_from_spherical_harmonics
|
5 |
+
|
6 |
+
class AttendjointsDecoder_con_combine(nn.Module):
|
7 |
+
def __init__(self,
|
8 |
+
width = 768,
|
9 |
+
layers = 2,
|
10 |
+
device = 'cuda',
|
11 |
+
dtype = torch.float32,
|
12 |
+
heads = 12,
|
13 |
+
init_scale: float = 0.25,
|
14 |
+
flash = False,
|
15 |
+
use_checkpoint = False,
|
16 |
+
qkv_bias = False,
|
17 |
+
num_freqs: int = 8,
|
18 |
+
include_pi: bool = True,
|
19 |
+
separate = False,
|
20 |
+
use_mask = True):
|
21 |
+
|
22 |
+
super().__init__()
|
23 |
+
|
24 |
+
self.use_checkpoint = use_checkpoint
|
25 |
+
self.separate = separate
|
26 |
+
self.use_mask = use_mask
|
27 |
+
# self.num_latents = num_latents
|
28 |
+
|
29 |
+
# self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
|
30 |
+
|
31 |
+
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
32 |
+
self.co_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
|
33 |
+
|
34 |
+
# self.proj_attn = nn.Linear(width, width, device=device, dtype=dtype)
|
35 |
+
|
36 |
+
self.cross_attn = nn.ModuleList([ResidualCrossAttentionBlock(
|
37 |
+
device=device,
|
38 |
+
dtype=dtype,
|
39 |
+
width=width,
|
40 |
+
heads=heads,
|
41 |
+
init_scale=init_scale,
|
42 |
+
qkv_bias=qkv_bias,
|
43 |
+
flash=flash,
|
44 |
+
) for _ in range(layers)])
|
45 |
+
|
46 |
+
self.self_attn = nn.ModuleList([ResidualAttentionBlock(
|
47 |
+
device=device,
|
48 |
+
dtype=dtype,
|
49 |
+
n_ctx=-1,
|
50 |
+
width=width,
|
51 |
+
heads=heads,
|
52 |
+
init_scale=init_scale,
|
53 |
+
qkv_bias=qkv_bias,
|
54 |
+
flash=flash,
|
55 |
+
) for _ in range(layers * 2)])
|
56 |
+
|
57 |
+
# self.joint_embed_proj = nn.ModuleList([nn.Linear(width, width, device=device, dtype=dtype) for _ in range(layers)])
|
58 |
+
|
59 |
+
|
60 |
+
self.q_proj = nn.Linear(width, width, device=device, dtype=dtype)
|
61 |
+
self.k_proj = nn.Linear(width, width, device=device, dtype=dtype)
|
62 |
+
self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
|
63 |
+
self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
|
64 |
+
|
65 |
+
# self.last_cross_attn = ResidualCrossAttentionBlock(
|
66 |
+
# device=device,
|
67 |
+
# dtype=dtype,
|
68 |
+
# width=width,
|
69 |
+
# heads=heads,
|
70 |
+
# init_scale=init_scale,
|
71 |
+
# qkv_bias=qkv_bias,
|
72 |
+
# flash=flash,
|
73 |
+
# )
|
74 |
+
# self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
|
75 |
+
# self.output_proj = nn.Linear(width, 1, device=device, dtype=dtype)
|
76 |
+
|
77 |
+
def forward(self, latents, data=None, device='cuda', downsample=None, dtype=torch.float32):
|
78 |
+
|
79 |
+
joints = data['joints'].to(device)
|
80 |
+
max_joints = max(data['joints_num'])
|
81 |
+
joints = joints[:, :max_joints, :3]
|
82 |
+
|
83 |
+
joints_embeds = self.fourier_embedder(joints)
|
84 |
+
joints_embeds = self.co_proj(joints_embeds)
|
85 |
+
|
86 |
+
joints_num = joints_embeds.shape[-2]
|
87 |
+
|
88 |
+
x = [joints_embeds, joints_embeds.clone()]
|
89 |
+
|
90 |
+
for i in range(2):
|
91 |
+
for j, layer in enumerate(self.cross_attn):
|
92 |
+
|
93 |
+
x[i] = layer(x[i], latents)
|
94 |
+
|
95 |
+
if self.use_mask:
|
96 |
+
x[i] = self.self_attn[2*i+j](x[i], mask=data['joints_mask'].to(device))
|
97 |
+
else:
|
98 |
+
x[i] = self.self_attn[2*i+j](x[i])
|
99 |
+
|
100 |
+
# Dot Product between points and joints
|
101 |
+
logits = torch.einsum('bnc,bmc->bnm', self.k_proj(self.ln_1(x[0])), self.q_proj(self.ln_2(x[1]))) # (b, n, m)
|
102 |
+
|
103 |
+
if self.use_mask:
|
104 |
+
mask = data['joints_mask'].to(device)
|
105 |
+
logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e8)
|
106 |
+
|
107 |
+
return logits
|
108 |
+
|
109 |
+
class AttendjointsDecoder_con_token(nn.Module):
|
110 |
+
def __init__(self,
|
111 |
+
width = 768,
|
112 |
+
layers = 4,
|
113 |
+
device = 'cuda',
|
114 |
+
dtype = torch.float32,
|
115 |
+
heads = 12,
|
116 |
+
init_scale: float = 0.25,
|
117 |
+
flash = False,
|
118 |
+
use_checkpoint = False,
|
119 |
+
qkv_bias = False,
|
120 |
+
num_freqs: int = 8,
|
121 |
+
include_pi: bool = True,
|
122 |
+
head_token_length =128,
|
123 |
+
separate = False,
|
124 |
+
use_mask = True):
|
125 |
+
|
126 |
+
super().__init__()
|
127 |
+
|
128 |
+
self.use_checkpoint = use_checkpoint
|
129 |
+
self.use_mask = use_mask
|
130 |
+
self.layer_norm = nn.LayerNorm(width)
|
131 |
+
self.head_token = nn.Parameter(torch.randn((1, 1, head_token_length), device=device, dtype=dtype) * 0.02)
|
132 |
+
self.tail_token = nn.Parameter(torch.randn((1, 1, head_token_length), device=device, dtype=dtype) * 0.02)
|
133 |
+
self.head_mlp = nn.ModuleList([
|
134 |
+
nn.Linear(width + head_token_length, 512, device=device, dtype=dtype),
|
135 |
+
nn.Linear(512, 512, device=device, dtype=dtype),
|
136 |
+
nn.Linear(512, width, device=device, dtype=dtype),
|
137 |
+
nn.LayerNorm(width)
|
138 |
+
|
139 |
+
])
|
140 |
+
self.tail_mlp = nn.ModuleList([
|
141 |
+
nn.Linear(width + head_token_length, 512, device=device, dtype=dtype),
|
142 |
+
nn.Linear(512, 512, device=device, dtype=dtype),
|
143 |
+
nn.Linear(512, width, device=device, dtype=dtype),
|
144 |
+
nn.LayerNorm(width)
|
145 |
+
])
|
146 |
+
|
147 |
+
self.self_attn = Transformer(
|
148 |
+
device=device,
|
149 |
+
dtype=dtype,
|
150 |
+
n_ctx=-1,
|
151 |
+
width=width,
|
152 |
+
layers=layers,
|
153 |
+
heads=heads,
|
154 |
+
init_scale=init_scale,
|
155 |
+
qkv_bias=qkv_bias,
|
156 |
+
flash=flash,
|
157 |
+
use_checkpoint=False,
|
158 |
+
)
|
159 |
+
self.separate = separate
|
160 |
+
self.normal_embedder = components_from_spherical_harmonics
|
161 |
+
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
162 |
+
self.joints_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
|
163 |
+
self.output_proj_joints = nn.Linear(width, width, device=device, dtype=dtype)
|
164 |
+
|
165 |
+
def forward(self, latents, data=None,device='cuda', downsample=None, dtype='float32'):
|
166 |
+
joints = data['joints'].to(device)
|
167 |
+
max_joints = max(data['joints_num'])
|
168 |
+
joints = joints[:, :max_joints, :3]
|
169 |
+
joints_embeds_fourier = self.fourier_embedder(joints)
|
170 |
+
joints_embeds = self.joints_proj(joints_embeds_fourier)
|
171 |
+
# Concatenate embeddings
|
172 |
+
x = torch.cat([joints_embeds, latents], dim=-2) # (b, max_joint+token_num, c)
|
173 |
+
# Pass through self-attention
|
174 |
+
if self.use_mask:
|
175 |
+
mask = data['mask'].to(device)
|
176 |
+
append_size = x.shape[1]-mask.shape[1] # the zero needs to append after mask
|
177 |
+
batch_size = mask.shape[0]
|
178 |
+
|
179 |
+
mask_extend = torch.ones((batch_size,append_size)).to(device)
|
180 |
+
mask = torch.cat([mask,mask_extend],dim=-1).to(device)
|
181 |
+
|
182 |
+
x = self.self_attn(x,mask)
|
183 |
+
else:
|
184 |
+
x = self.self_attn(x)
|
185 |
+
joints, _= x.split([joints_embeds.shape[1], latents.shape[1]], dim=1)
|
186 |
+
joints = self.output_proj_joints(self.layer_norm(joints))
|
187 |
+
joints_head = torch.concat([joints, self.head_token.repeat(joints.shape[0],joints.shape[1],1)], dim=-1)
|
188 |
+
joints_tail = torch.concat([joints, self.tail_token.repeat(joints.shape[0],joints.shape[1],1)], dim=-1)
|
189 |
+
for layer in self.head_mlp:
|
190 |
+
joints_head = layer(joints_head)
|
191 |
+
for layer in self.tail_mlp:
|
192 |
+
joints_tail = layer(joints_tail)
|
193 |
+
logits = torch.einsum('bik,bjk->bij', joints_head, joints_tail)
|
194 |
+
|
195 |
+
return logits
|
Anymate/models/diffusion.py
ADDED
@@ -0,0 +1,483 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
F"""
|
2 |
+
Adapted from: https://github.com/openai/openai/blob/55363aa496049423c37124b440e9e30366db3ed6/orc/orc/diffusion/vit.py
|
3 |
+
"""
|
4 |
+
|
5 |
+
import math
|
6 |
+
from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, Callable
|
7 |
+
|
8 |
+
import torch
|
9 |
+
import torch.nn as nn
|
10 |
+
|
11 |
+
from einops import repeat
|
12 |
+
from Anymate.utils.diffusion_utils import *
|
13 |
+
from ThirdParty.michelangelo.models.modules.transformer_blocks import Transformer, ResidualCrossAttentionBlock
|
14 |
+
|
15 |
+
from diffusers import DDPMScheduler, DDIMScheduler
|
16 |
+
from sklearn.cluster import DBSCAN
|
17 |
+
|
18 |
+
def init_linear(l, stddev):
|
19 |
+
nn.init.normal_(l.weight, std=stddev)
|
20 |
+
if l.bias is not None:
|
21 |
+
nn.init.constant_(l.bias, 0.0)
|
22 |
+
|
23 |
+
class projection_transformer(nn.Module):
|
24 |
+
def __init__(self, num_latents=16, width = 16, heads=8, dtype = torch.float32):
|
25 |
+
super().__init__()
|
26 |
+
self.num_latents = num_latents
|
27 |
+
self.query = nn.Parameter(torch.randn((num_latents, width), dtype=dtype) * 0.02)
|
28 |
+
|
29 |
+
self.cross_attn = ResidualCrossAttentionBlock(
|
30 |
+
device= 'cuda',
|
31 |
+
dtype=dtype,
|
32 |
+
width=width,
|
33 |
+
heads=heads,
|
34 |
+
init_scale=0.25,
|
35 |
+
qkv_bias=True,
|
36 |
+
flash=False,
|
37 |
+
)
|
38 |
+
self.output_proj = nn.Linear(width, width,dtype=dtype)
|
39 |
+
|
40 |
+
def forward(self, latents):
|
41 |
+
bs = latents.shape[0]
|
42 |
+
query = repeat(self.query, "m c -> b m c", b=bs)
|
43 |
+
embed = self.cross_attn(query, latents)
|
44 |
+
logits = self.output_proj(embed)
|
45 |
+
|
46 |
+
return logits
|
47 |
+
|
48 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
|
49 |
+
"""
|
50 |
+
Create sinusoidal timestep embeddings.
|
51 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
52 |
+
These may be fractional.
|
53 |
+
:param dim: the dimension of the output.
|
54 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
55 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
56 |
+
"""
|
57 |
+
half = dim // 2
|
58 |
+
freqs = torch.exp(
|
59 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
60 |
+
).to(device=timesteps.device)
|
61 |
+
args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
|
62 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
63 |
+
if dim % 2:
|
64 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
65 |
+
return embedding
|
66 |
+
|
67 |
+
class MultiheadAttention(nn.Module):
|
68 |
+
def __init__(
|
69 |
+
self,
|
70 |
+
*,
|
71 |
+
dtype: torch.dtype,
|
72 |
+
n_ctx: int,
|
73 |
+
width: int,
|
74 |
+
heads: int,
|
75 |
+
init_scale: float,
|
76 |
+
):
|
77 |
+
super().__init__()
|
78 |
+
self.n_ctx = n_ctx
|
79 |
+
self.width = width
|
80 |
+
self.heads = heads
|
81 |
+
self.c_qkv = nn.Linear(width, width * 3, dtype=dtype)
|
82 |
+
self.c_proj = nn.Linear(width, width, dtype=dtype)
|
83 |
+
self.attention = QKVMultiheadAttention(dtype=dtype, heads=heads, n_ctx=n_ctx)
|
84 |
+
init_linear(self.c_qkv, init_scale)
|
85 |
+
init_linear(self.c_proj, init_scale)
|
86 |
+
|
87 |
+
def forward(self, x):
|
88 |
+
x = self.c_qkv(x)
|
89 |
+
x = self.attention(x)
|
90 |
+
x = self.c_proj(x)
|
91 |
+
return x
|
92 |
+
|
93 |
+
class MLP(nn.Module):
|
94 |
+
def __init__(self, *, dtype: torch.dtype, width: int, init_scale: float):
|
95 |
+
super().__init__()
|
96 |
+
self.width = width
|
97 |
+
self.c_fc = nn.Linear(width, width * 4, dtype=dtype)
|
98 |
+
self.c_proj = nn.Linear(width * 4, width, dtype=dtype)
|
99 |
+
self.gelu = nn.GELU()
|
100 |
+
init_linear(self.c_fc, init_scale)
|
101 |
+
init_linear(self.c_proj, init_scale)
|
102 |
+
|
103 |
+
def forward(self, x):
|
104 |
+
return self.c_proj(self.gelu(self.c_fc(x)))
|
105 |
+
|
106 |
+
class QKVMultiheadAttention(nn.Module):
|
107 |
+
def __init__(self, *, dtype: torch.dtype, heads: int, n_ctx: int):
|
108 |
+
super().__init__()
|
109 |
+
self.dtype = dtype
|
110 |
+
self.heads = heads
|
111 |
+
self.n_ctx = n_ctx
|
112 |
+
|
113 |
+
def forward(self, qkv):
|
114 |
+
bs, n_ctx, width = qkv.shape
|
115 |
+
attn_ch = width // self.heads // 3
|
116 |
+
scale = 1 / math.sqrt(math.sqrt(attn_ch))
|
117 |
+
qkv = qkv.view(bs, n_ctx, self.heads, -1)
|
118 |
+
q, k, v = torch.split(qkv, attn_ch, dim=-1)
|
119 |
+
weight = torch.einsum(
|
120 |
+
"bthc,bshc->bhts", q * scale, k * scale
|
121 |
+
) # More stable with f16 than dividing afterwards
|
122 |
+
wdtype = weight.dtype
|
123 |
+
weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
|
124 |
+
return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
|
125 |
+
|
126 |
+
class ResidualAttentionBlock(nn.Module):
|
127 |
+
def __init__(
|
128 |
+
self,
|
129 |
+
*,
|
130 |
+
dtype: torch.dtype,
|
131 |
+
n_ctx: int,
|
132 |
+
width: int,
|
133 |
+
heads: int,
|
134 |
+
init_scale: float = 1.0,
|
135 |
+
):
|
136 |
+
super().__init__()
|
137 |
+
|
138 |
+
self.attn = MultiheadAttention(
|
139 |
+
dtype=dtype,
|
140 |
+
n_ctx=n_ctx,
|
141 |
+
width=width,
|
142 |
+
heads=heads,
|
143 |
+
init_scale=init_scale,
|
144 |
+
)
|
145 |
+
self.ln_1 = nn.LayerNorm(width, dtype=dtype)
|
146 |
+
self.mlp = MLP(dtype=dtype, width=width, init_scale=init_scale)
|
147 |
+
self.ln_2 = nn.LayerNorm(width, dtype=dtype)
|
148 |
+
|
149 |
+
def forward(self, x: torch.Tensor):
|
150 |
+
x = x + self.attn(self.ln_1(x))
|
151 |
+
x = x + self.mlp(self.ln_2(x))
|
152 |
+
return x
|
153 |
+
|
154 |
+
class Transformer(nn.Module):
|
155 |
+
def __init__(
|
156 |
+
self,
|
157 |
+
*,
|
158 |
+
dtype: torch.dtype,
|
159 |
+
n_ctx: int,
|
160 |
+
width: int,
|
161 |
+
layers: int,
|
162 |
+
heads: int,
|
163 |
+
init_scale: float = 0.25,
|
164 |
+
):
|
165 |
+
super().__init__()
|
166 |
+
self.n_ctx = n_ctx
|
167 |
+
self.width = width
|
168 |
+
self.layers = layers
|
169 |
+
init_scale = init_scale * math.sqrt(1.0 / width)
|
170 |
+
self.resblocks = nn.ModuleList(
|
171 |
+
[
|
172 |
+
ResidualAttentionBlock(
|
173 |
+
dtype=dtype,
|
174 |
+
n_ctx=n_ctx,
|
175 |
+
width=width,
|
176 |
+
heads=heads,
|
177 |
+
init_scale=init_scale,
|
178 |
+
)
|
179 |
+
for _ in range(layers)
|
180 |
+
]
|
181 |
+
)
|
182 |
+
|
183 |
+
def forward(self, x: torch.Tensor):
|
184 |
+
for block in self.resblocks:
|
185 |
+
x = block(x)
|
186 |
+
return x
|
187 |
+
|
188 |
+
class PointDiffusionTransformer(nn.Module):
|
189 |
+
def __init__(
|
190 |
+
self,
|
191 |
+
*,
|
192 |
+
dtype: torch.dtype,
|
193 |
+
input_channels: int = 3,
|
194 |
+
output_channels: int = 3,
|
195 |
+
n_ctx: int = 1024,
|
196 |
+
width: int = 768,
|
197 |
+
layers: int = 12,
|
198 |
+
heads: int = 8,
|
199 |
+
init_scale: float = 0.25,
|
200 |
+
time_token_cond: bool = True,
|
201 |
+
):
|
202 |
+
super().__init__()
|
203 |
+
self.input_channels = input_channels
|
204 |
+
self.output_channels = output_channels
|
205 |
+
self.n_ctx = n_ctx
|
206 |
+
self.time_token_cond = time_token_cond
|
207 |
+
self.time_embed = MLP(
|
208 |
+
dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
|
209 |
+
)
|
210 |
+
self.ln_pre = nn.LayerNorm(width, dtype=dtype)
|
211 |
+
self.backbone = Transformer(
|
212 |
+
dtype=dtype,
|
213 |
+
n_ctx=n_ctx + int(time_token_cond),
|
214 |
+
width=width,
|
215 |
+
layers=layers,
|
216 |
+
heads=heads,
|
217 |
+
init_scale=init_scale,
|
218 |
+
)
|
219 |
+
self.ln_post = nn.LayerNorm(width,dtype=dtype)
|
220 |
+
self.input_proj = nn.Linear(input_channels, width, dtype=dtype)
|
221 |
+
self.output_proj = nn.Linear(width, output_channels,dtype=dtype)
|
222 |
+
with torch.no_grad():
|
223 |
+
self.output_proj.weight.zero_()
|
224 |
+
self.output_proj.bias.zero_()
|
225 |
+
|
226 |
+
def forward(self, x: torch.Tensor, t: torch.Tensor):
|
227 |
+
"""
|
228 |
+
:param x: an [N x C x T] tensor.
|
229 |
+
:param t: an [N] tensor.
|
230 |
+
:return: an [N x C' x T] tensor.
|
231 |
+
"""
|
232 |
+
assert x.shape[-1] == self.n_ctx
|
233 |
+
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
|
234 |
+
return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])
|
235 |
+
|
236 |
+
def _forward_with_cond(
|
237 |
+
self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]]
|
238 |
+
) -> torch.Tensor:
|
239 |
+
h = self.input_proj(x.permute(0, 2, 1)) # NCL -> NLC
|
240 |
+
for emb, as_token in cond_as_token:
|
241 |
+
if not as_token:
|
242 |
+
h = h + emb[:, None]
|
243 |
+
extra_tokens = [
|
244 |
+
(emb[:, None] if len(emb.shape) == 2 else emb)
|
245 |
+
for emb, as_token in cond_as_token
|
246 |
+
if as_token
|
247 |
+
]
|
248 |
+
if len(extra_tokens):
|
249 |
+
h = torch.cat(extra_tokens + [h], dim=1)
|
250 |
+
|
251 |
+
h = self.ln_pre(h)
|
252 |
+
h = self.backbone(h)
|
253 |
+
h = self.ln_post(h)
|
254 |
+
if len(extra_tokens):
|
255 |
+
h = h[:, sum(h.shape[1] for h in extra_tokens) :]
|
256 |
+
h = self.output_proj(h)
|
257 |
+
return h.permute(0, 2, 1)
|
258 |
+
|
259 |
+
class Pointe_Diffusion(PointDiffusionTransformer):
|
260 |
+
'''
|
261 |
+
input: data: data dict
|
262 |
+
x: [N x C x T] tensor
|
263 |
+
t: [N] tensor
|
264 |
+
init:
|
265 |
+
n_ctx: int = 1024: context length
|
266 |
+
'''
|
267 |
+
def __init__(
|
268 |
+
self,
|
269 |
+
*,
|
270 |
+
device = 'cuda',
|
271 |
+
dtype = torch.float32,
|
272 |
+
encoder = 'miche',
|
273 |
+
n_ctx: int = 1024,
|
274 |
+
token_cond: bool = True,
|
275 |
+
cond_drop_prob: float = 0.1,
|
276 |
+
fix_emb: bool = False,
|
277 |
+
|
278 |
+
**kwargs,
|
279 |
+
):
|
280 |
+
super().__init__(dtype=dtype, n_ctx=n_ctx + int(token_cond), **kwargs)
|
281 |
+
self.n_ctx = n_ctx
|
282 |
+
self.token_cond = token_cond
|
283 |
+
# self.proj_transformer = projection_transformer(**kwargs)
|
284 |
+
self.encoder_name = encoder
|
285 |
+
self.cond_drop_prob = cond_drop_prob
|
286 |
+
self.fix_emb = fix_emb
|
287 |
+
self.dtype = dtype
|
288 |
+
self.inference = False
|
289 |
+
def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
290 |
+
with torch.no_grad():
|
291 |
+
return dict(embeddings=self.clip(batch_size, **model_kwargs))
|
292 |
+
|
293 |
+
def inference_mode(self,eps=0.03):
|
294 |
+
self.inference = True
|
295 |
+
|
296 |
+
def forward_func(
|
297 |
+
self,
|
298 |
+
latent: torch.Tensor,
|
299 |
+
data,
|
300 |
+
device='cuda',
|
301 |
+
downsample = False,
|
302 |
+
**kwargs,
|
303 |
+
):
|
304 |
+
t = kwargs['timesteps'].to(latent.device)
|
305 |
+
x = kwargs['noisy_joints'].to(latent.device)
|
306 |
+
assert x.shape[-1] == self.n_ctx, f"x shape: {x.shape}, n_ctx: {self.n_ctx}"
|
307 |
+
t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
|
308 |
+
|
309 |
+
if self.training:
|
310 |
+
mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
|
311 |
+
latent = latent * mask[:,None,None].to(latent.device)
|
312 |
+
|
313 |
+
latent = [(latent, self.token_cond), (t_embed, self.time_token_cond)]
|
314 |
+
return self._forward_with_cond(x, latent)
|
315 |
+
|
316 |
+
def forward(self, latent, data, device='cuda', downsample = False, **kwargs):
|
317 |
+
if self.inference == False:
|
318 |
+
return self.forward_func(latent, data, device, downsample, **kwargs)
|
319 |
+
else:
|
320 |
+
generator=torch.Generator(device='cpu')
|
321 |
+
scheduler = DDIMScheduler(100)
|
322 |
+
scheduler.set_timesteps(100)
|
323 |
+
points_shape = [1, self.n_ctx, 3]
|
324 |
+
|
325 |
+
points_noise = randn_tensor(points_shape, generator=generator)
|
326 |
+
points = points_noise.permute(0, 2, 1).to(latent.device)
|
327 |
+
for t in scheduler.timesteps:
|
328 |
+
with torch.no_grad():
|
329 |
+
time_steps = torch.ones(1, 1, dtype=torch.long) * t
|
330 |
+
model_output = self.forward_func(latent, data, noisy_joints=points, timesteps = time_steps)
|
331 |
+
|
332 |
+
points = scheduler.step(model_output, t, points, generator=generator).prev_sample
|
333 |
+
points = points.permute(0, 2, 1).cpu()
|
334 |
+
assert points.shape[0] == 1, "Inference mode only supports batch size 1"
|
335 |
+
joints = points[0].detach().cpu().numpy()
|
336 |
+
clustering = DBSCAN(eps=0.05, min_samples=1).fit(joints)
|
337 |
+
cluster_centers = []
|
338 |
+
for cluster in set(clustering.labels_):
|
339 |
+
cluster_centers.append(joints[clustering.labels_ == cluster].mean(axis=0))
|
340 |
+
return cluster_centers
|
341 |
+
|
342 |
+
class Cross_Attention_Diffusion(nn.Module):
|
343 |
+
def __init__(self,
|
344 |
+
input_channels=3, output_channels=3,
|
345 |
+
num_z=16, num_x=1024, z_dim=768, x_dim=512,
|
346 |
+
num_blocks=6, num_compute_layers=4, num_heads=8,
|
347 |
+
mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
|
348 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,num_latents=16,
|
349 |
+
device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
|
350 |
+
use_projection = True,):
|
351 |
+
super().__init__()
|
352 |
+
self.use_projection = use_projection
|
353 |
+
self.device = device
|
354 |
+
self.num_z = num_z
|
355 |
+
self.num_x = num_x
|
356 |
+
self.z_dim = z_dim
|
357 |
+
if use_projection:
|
358 |
+
self.proj_transformer = projection_transformer(num_latents=num_latents, width=z_dim, heads=num_heads)
|
359 |
+
self.prev_latent = nn.Parameter(torch.zeros(1, self.num_z + num_latents + 1, z_dim))
|
360 |
+
self.inference = False
|
361 |
+
|
362 |
+
self.input_proj = nn.Linear(input_channels, x_dim)
|
363 |
+
self.ln_pre = nn.LayerNorm(x_dim)
|
364 |
+
self.z_init = nn.Parameter(torch.zeros(1, num_z, z_dim))
|
365 |
+
|
366 |
+
mlp_hidden_dim = int(z_dim * mlp_ratio)
|
367 |
+
self.time_embed = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim)
|
368 |
+
|
369 |
+
self.latent_mlp = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
370 |
+
self.ln_latent = nn.LayerNorm(z_dim)
|
371 |
+
self.blocks = nn.ModuleList([
|
372 |
+
RCW_Block(z_dim, x_dim, num_compute_layers=num_compute_layers,
|
373 |
+
num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
|
374 |
+
drop=drop, attn_drop=attn_drop, drop_path=drop_path,
|
375 |
+
act_layer=act_layer, norm_layer=norm_layer)
|
376 |
+
for _ in range(num_blocks)
|
377 |
+
])
|
378 |
+
|
379 |
+
# output blocks
|
380 |
+
self.ln_post = nn.LayerNorm(x_dim)
|
381 |
+
self.output_proj = nn.Linear(x_dim, output_channels)
|
382 |
+
|
383 |
+
self.initialize_weights()
|
384 |
+
|
385 |
+
def initialize_weights(self):
|
386 |
+
nn.init.normal_(self.z_init, std=.02)
|
387 |
+
|
388 |
+
# initialize nn.Linear and nn.LayerNorm
|
389 |
+
self.apply(self._init_weights)
|
390 |
+
|
391 |
+
nn.init.constant_(self.ln_latent.weight, 0)
|
392 |
+
nn.init.constant_(self.ln_latent.bias, 0)
|
393 |
+
|
394 |
+
def _init_weights(self, m):
|
395 |
+
if isinstance(m, nn.Linear):
|
396 |
+
torch.nn.init.xavier_uniform_(m.weight)
|
397 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
398 |
+
nn.init.constant_(m.bias, 0)
|
399 |
+
elif isinstance(m, nn.LayerNorm):
|
400 |
+
nn.init.constant_(m.bias, 0)
|
401 |
+
nn.init.constant_(m.weight, 1.0)
|
402 |
+
|
403 |
+
def inference_mode(self,eps=0.03):
|
404 |
+
self.inference = True
|
405 |
+
|
406 |
+
def forward_func(self, latent, data, device='cuda', downsample = False, **kwargs):
|
407 |
+
"""
|
408 |
+
Forward pass of the model.
|
409 |
+
|
410 |
+
Parameters:
|
411 |
+
x: [B, num_x, C_in]
|
412 |
+
t: [B]
|
413 |
+
cond: [B, num_cond, C_latent]
|
414 |
+
prev_latent: [B, num_z + num_cond + 1, C_latent]
|
415 |
+
|
416 |
+
Returns:
|
417 |
+
x_denoised: [B, num_x, C_out]
|
418 |
+
z: [B, num_z + num_cond + 1, C_latent]
|
419 |
+
"""
|
420 |
+
t = kwargs['timesteps'].to(latent.device)
|
421 |
+
x = kwargs['noisy_joints'].to(latent.device)
|
422 |
+
x = x.permute(0, 2, 1)
|
423 |
+
B, num_x, _ = x.shape
|
424 |
+
if self.use_projection:
|
425 |
+
latent = self.proj_transformer(latent)
|
426 |
+
assert num_x == self.num_x, f"x shape: {x.shape}, num_x: {self.num_x}"
|
427 |
+
# if prev_latent is not None:
|
428 |
+
# _, num_z, _ = prev_latent.shape
|
429 |
+
# assert num_z == self.num_z + num_cond + 1
|
430 |
+
# else:
|
431 |
+
# prev_latent = torch.zeros(B, self.num_z + num_cond + 1, self.z_dim).to(x.device)
|
432 |
+
|
433 |
+
# timestep embedding, [B, 1, z_dim]
|
434 |
+
t_embed = self.time_embed(timestep_embedding(t, self.z_dim))
|
435 |
+
if t_embed.dim() == 2:
|
436 |
+
t_embed = t_embed.unsqueeze(1)
|
437 |
+
|
438 |
+
# project x -> [B, num_x, C_x]
|
439 |
+
x = self.input_proj(x)
|
440 |
+
x = self.ln_pre(x)
|
441 |
+
|
442 |
+
# latent self-conditioning
|
443 |
+
z = self.z_init.repeat(B, 1, 1) # [B, num_z, z_dim
|
444 |
+
z = torch.cat([z, latent, t_embed], dim=1) # [B, num_z + num_cond + 1, z_dim]
|
445 |
+
prev_latent = self.prev_latent + self.latent_mlp(self.prev_latent.detach())
|
446 |
+
z = z + (self.ln_latent(prev_latent))
|
447 |
+
|
448 |
+
# compute
|
449 |
+
for blk in self.blocks:
|
450 |
+
z, x = blk(z, x)
|
451 |
+
|
452 |
+
# output proj
|
453 |
+
x = self.ln_post(x)
|
454 |
+
x_denoised = self.output_proj(x)
|
455 |
+
return x_denoised.permute(0, 2, 1)
|
456 |
+
|
457 |
+
def forward(self, latent, data, device='cuda', downsample = False, **kwargs):
|
458 |
+
if self.inference == False:
|
459 |
+
return self.forward_func(latent, data, device, downsample, **kwargs)
|
460 |
+
else:
|
461 |
+
generator=torch.Generator(device='cpu')
|
462 |
+
scheduler = DDIMScheduler(100)
|
463 |
+
scheduler.set_timesteps(100)
|
464 |
+
points_shape = [1, self.num_x, 3]
|
465 |
+
|
466 |
+
points_noise = randn_tensor(points_shape, generator=generator)
|
467 |
+
points = points_noise.permute(0, 2, 1).to(latent.device)
|
468 |
+
for t in scheduler.timesteps:
|
469 |
+
with torch.no_grad():
|
470 |
+
time_steps = torch.ones(1, 1, dtype=torch.long) * t
|
471 |
+
time_steps = time_steps.to(latent.device)
|
472 |
+
model_output = self.forward_func(latent, data, noisy_joints=points, timesteps = time_steps)
|
473 |
+
|
474 |
+
points = scheduler.step(model_output, t, points, generator=generator).prev_sample
|
475 |
+
points = points.permute(0, 2, 1).cpu()
|
476 |
+
assert points.shape[0] == 1, "Inference mode only supports batch size 1"
|
477 |
+
joints = points[0].detach().cpu().numpy()
|
478 |
+
clustering = DBSCAN(eps=0.05, min_samples=1).fit(joints)
|
479 |
+
cluster_centers = []
|
480 |
+
for cluster in set(clustering.labels_):
|
481 |
+
cluster_centers.append(joints[clustering.labels_ == cluster].mean(axis=0))
|
482 |
+
return cluster_centers
|
483 |
+
|
Anymate/models/joint.py
ADDED
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from ThirdParty.michelangelo.models.modules.embedder import FourierEmbedder
|
4 |
+
from ThirdParty.michelangelo.models.modules.transformer_blocks import ResidualCrossAttentionBlock
|
5 |
+
from ThirdParty.eg3d.training.networks_stylegan2 import Generator as StyleGAN2Backbone
|
6 |
+
from ThirdParty.eg3d.training.networks_stylegan2 import FullyConnectedLayer
|
7 |
+
from Anymate.utils.vol_utils import get_co, sample_from_planes, generate_planes
|
8 |
+
from einops import repeat
|
9 |
+
from sklearn.cluster import DBSCAN
|
10 |
+
from Anymate.utils.vol_utils import extract_keypoints
|
11 |
+
|
12 |
+
class TransformerDecoder(nn.Module):
|
13 |
+
def __init__(self,
|
14 |
+
num_latents = 96,
|
15 |
+
num_kv_latents = 257,
|
16 |
+
out_channels = 3,
|
17 |
+
width = 768,
|
18 |
+
layers = 7,
|
19 |
+
device = 'cuda',
|
20 |
+
dtype = torch.float32,
|
21 |
+
heads = 12,
|
22 |
+
init_scale: float = 0.25,
|
23 |
+
flash = False,
|
24 |
+
use_checkpoint = False,
|
25 |
+
qkv_bias = False):
|
26 |
+
|
27 |
+
super().__init__()
|
28 |
+
|
29 |
+
self.use_checkpoint = use_checkpoint
|
30 |
+
self.num_latents = num_latents
|
31 |
+
self.inference = False
|
32 |
+
self.eps = 0.03
|
33 |
+
|
34 |
+
self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
|
35 |
+
|
36 |
+
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
37 |
+
device=device,
|
38 |
+
dtype=dtype,
|
39 |
+
n_data=num_kv_latents,
|
40 |
+
width=width,
|
41 |
+
heads=heads,
|
42 |
+
init_scale=init_scale,
|
43 |
+
qkv_bias=qkv_bias,
|
44 |
+
flash=flash
|
45 |
+
)
|
46 |
+
|
47 |
+
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
|
48 |
+
self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
|
49 |
+
|
50 |
+
def inference_mode(self, eps=0.03, min_samples=1):
|
51 |
+
self.inference = True
|
52 |
+
self.eps = eps
|
53 |
+
self.min_samples = min_samples
|
54 |
+
|
55 |
+
def forward(self, latents, data=None, device='cuda', downsample=False, dtype=torch.float32):
|
56 |
+
|
57 |
+
bs = latents.shape[0]
|
58 |
+
query = repeat(self.query, "m c -> b m c", b=bs)
|
59 |
+
logits = self.cross_attn_decoder(query, latents)
|
60 |
+
logits = self.ln_post(logits)
|
61 |
+
logits = self.output_proj(logits)
|
62 |
+
if self.inference:
|
63 |
+
assert logits.shape[0] == 1, "Inference mode only supports batch size 1"
|
64 |
+
joints = logits[0].detach().cpu().numpy()
|
65 |
+
clustering = DBSCAN(eps=self.eps, min_samples=self.min_samples).fit(joints)
|
66 |
+
cluster_centers = []
|
67 |
+
for cluster in set(clustering.labels_):
|
68 |
+
cluster_centers.append(joints[clustering.labels_ == cluster].mean(axis=0))
|
69 |
+
return cluster_centers
|
70 |
+
return logits
|
71 |
+
|
72 |
+
|
73 |
+
class ImplicitTransformerDecoder(nn.Module):
|
74 |
+
|
75 |
+
def __init__(self, *,
|
76 |
+
device = 'cuda',
|
77 |
+
dtype = torch.float32,
|
78 |
+
num_latents = 257,
|
79 |
+
out_channels = 1,
|
80 |
+
width = 768,
|
81 |
+
heads = 12,
|
82 |
+
num_freqs: int = 8,
|
83 |
+
include_pi: bool = True,
|
84 |
+
init_scale: float = 0.25,
|
85 |
+
qkv_bias: bool = False,
|
86 |
+
flash: bool = False,
|
87 |
+
use_checkpoint: bool = False):
|
88 |
+
|
89 |
+
super().__init__()
|
90 |
+
|
91 |
+
self.use_checkpoint = use_checkpoint
|
92 |
+
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
93 |
+
self.inference = False
|
94 |
+
|
95 |
+
self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
|
96 |
+
|
97 |
+
self.cross_attn_decoder = ResidualCrossAttentionBlock(
|
98 |
+
device=device,
|
99 |
+
dtype=dtype,
|
100 |
+
n_data=num_latents,
|
101 |
+
width=width,
|
102 |
+
heads=heads,
|
103 |
+
init_scale=init_scale,
|
104 |
+
qkv_bias=qkv_bias,
|
105 |
+
flash=flash
|
106 |
+
)
|
107 |
+
|
108 |
+
self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
|
109 |
+
self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
|
110 |
+
|
111 |
+
# self.queries = get_vol().to(device)
|
112 |
+
|
113 |
+
def inference_mode(self):
|
114 |
+
self.inference = True
|
115 |
+
|
116 |
+
def forward(self, latents: torch.FloatTensor, data=None, device='cuda', downsample=False):
|
117 |
+
bs = latents.shape[0]
|
118 |
+
# queries = repeat(self.queries, "m c -> b m c", b=bs)
|
119 |
+
out = []
|
120 |
+
for b in range(bs):
|
121 |
+
queries = get_co(data['vox'][b]).to(device).unsqueeze(0)
|
122 |
+
if downsample and data['vox'][b].shape[0] > 50000:
|
123 |
+
# random sample
|
124 |
+
idx = torch.randperm(data['vox'][b].shape[0])[:50000]
|
125 |
+
queries = queries[:, idx]
|
126 |
+
queries = self.query_proj(self.fourier_embedder(queries))
|
127 |
+
x = self.cross_attn_decoder(queries, latents[b:b+1])
|
128 |
+
x = self.ln_post(x)
|
129 |
+
x = self.output_proj(x)
|
130 |
+
if downsample and data['vox'][b].shape[0] > 50000:
|
131 |
+
out.append((x.squeeze(0), idx))
|
132 |
+
else:
|
133 |
+
out.append(x.squeeze(0))
|
134 |
+
if self.inference:
|
135 |
+
assert len(out) == 1, "Inference mode only supports batch size 1"
|
136 |
+
return extract_keypoints(out[0], data['vox'][0])
|
137 |
+
|
138 |
+
return out
|
139 |
+
|
140 |
+
|
141 |
+
class TriPlaneDecoder(torch.nn.Module):
|
142 |
+
def __init__(self,
|
143 |
+
z_dim = 768, # Input latent (Z) dimensionality.
|
144 |
+
c_dim = 0, # Conditioning label (C) dimensionality.
|
145 |
+
w_dim = 768, # Intermediate latent (W) dimensionality.
|
146 |
+
# img_resolution, # Output resolution.
|
147 |
+
# img_channels, # Number of output color channels.
|
148 |
+
# sr_num_fp16_res = 0,
|
149 |
+
mapping_kwargs = {'num_layers': 2}, # Arguments for MappingNetwork.
|
150 |
+
# rendering_kwargs = {},
|
151 |
+
# sr_kwargs = {},
|
152 |
+
synthesis_kwargs = {'num_fp16_res': 0, 'conv_clamp': None, 'fused_modconv_default': 'inference_only'}, # Arguments for SynthesisNetwork.
|
153 |
+
):
|
154 |
+
super().__init__()
|
155 |
+
self.z_dim=z_dim
|
156 |
+
self.c_dim=c_dim
|
157 |
+
self.w_dim=w_dim
|
158 |
+
# self.img_resolution=img_resolution
|
159 |
+
# self.img_channels=img_channels
|
160 |
+
# self.renderer = ImportanceRenderer()
|
161 |
+
# self.ray_sampler = RaySampler()
|
162 |
+
self.backbone = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=256, img_channels=32*3, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
|
163 |
+
# self.superresolution = dnnlib.util.construct_class_by_name(class_name=rendering_kwargs['superresolution_module'], channels=32, img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs)
|
164 |
+
self.decoder = OSGDecoder(32, {'decoder_output_dim': 0})
|
165 |
+
self.inference = False
|
166 |
+
# self.neural_rendering_resolution = 64
|
167 |
+
# self.rendering_kwargs = rendering_kwargs
|
168 |
+
|
169 |
+
self._last_planes = None
|
170 |
+
self.plane_axes = generate_planes()
|
171 |
+
|
172 |
+
def mapping(self, z, c=None, truncation_psi=1, truncation_cutoff=None, update_emas=False):
|
173 |
+
# if self.rendering_kwargs['c_gen_conditioning_zero']:
|
174 |
+
# c = torch.zeros_like(c)
|
175 |
+
# return self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
|
176 |
+
return self.backbone.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
|
177 |
+
|
178 |
+
def synthesis(self, ws, c=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
|
179 |
+
# cam2world_matrix = c[:, :16].view(-1, 4, 4)
|
180 |
+
# intrinsics = c[:, 16:25].view(-1, 3, 3)
|
181 |
+
|
182 |
+
# if neural_rendering_resolution is None:
|
183 |
+
# neural_rendering_resolution = self.neural_rendering_resolution
|
184 |
+
# else:
|
185 |
+
# self.neural_rendering_resolution = neural_rendering_resolution
|
186 |
+
|
187 |
+
# Create a batch of rays for volume rendering
|
188 |
+
# ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)
|
189 |
+
|
190 |
+
# Create triplanes by running StyleGAN backbone
|
191 |
+
# N, M, _ = ray_origins.shape
|
192 |
+
if use_cached_backbone and self._last_planes is not None:
|
193 |
+
planes = self._last_planes
|
194 |
+
else:
|
195 |
+
planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
|
196 |
+
if cache_backbone:
|
197 |
+
self._last_planes = planes
|
198 |
+
|
199 |
+
# Reshape output into three 32-channel planes
|
200 |
+
planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
|
201 |
+
return planes
|
202 |
+
|
203 |
+
# Perform volume rendering
|
204 |
+
feature_samples, depth_samples, weights_samples = self.renderer(planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs) # channels last
|
205 |
+
|
206 |
+
# Reshape into 'raw' neural-rendered image
|
207 |
+
H = W = self.neural_rendering_resolution
|
208 |
+
feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
|
209 |
+
depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
|
210 |
+
|
211 |
+
# Run superresolution to get final image
|
212 |
+
rgb_image = feature_image[:, :3]
|
213 |
+
sr_image = self.superresolution(rgb_image, feature_image, ws, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k:synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
|
214 |
+
|
215 |
+
return {'image': sr_image, 'image_raw': rgb_image, 'image_depth': depth_image}
|
216 |
+
|
217 |
+
def sample(self, coordinates, directions, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
|
218 |
+
# Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
|
219 |
+
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
|
220 |
+
planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
|
221 |
+
planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
|
222 |
+
return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
|
223 |
+
|
224 |
+
def sample_mixed(self, coordinates, directions, ws, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
|
225 |
+
# Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
|
226 |
+
planes = self.backbone.synthesis(ws, update_emas = update_emas, **synthesis_kwargs)
|
227 |
+
planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
|
228 |
+
return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
|
229 |
+
|
230 |
+
def inference_mode(self):
|
231 |
+
self.inference = True
|
232 |
+
|
233 |
+
def forward(self, z, data=None, device='cuda', downsample=False, c=None, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
|
234 |
+
# Render a batch of generated images.
|
235 |
+
assert z.shape[-1] == self.z_dim
|
236 |
+
ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
|
237 |
+
planes = self.synthesis(ws, c, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, **synthesis_kwargs)
|
238 |
+
bs = planes.shape[0]
|
239 |
+
logits = []
|
240 |
+
for b in range(bs):
|
241 |
+
queries = get_co(data['vox'][b]).to(device).unsqueeze(0)
|
242 |
+
if downsample and data['vox'][b].shape[0] > 50000:
|
243 |
+
# random sample
|
244 |
+
idx = torch.randperm(data['vox'][b].shape[0])[:50000]
|
245 |
+
queries = queries[:, idx]
|
246 |
+
out = sample_from_planes(self.plane_axes.to(device), planes[b:b+1], queries)
|
247 |
+
out = self.decoder(out)
|
248 |
+
if downsample and data['vox'][b].shape[0] > 50000:
|
249 |
+
logits.append((out.squeeze(0), idx))
|
250 |
+
else:
|
251 |
+
logits.append(out.squeeze(0))
|
252 |
+
if self.inference:
|
253 |
+
assert len(logits) == 1, "Inference mode only supports batch size 1"
|
254 |
+
return extract_keypoints(logits[0], data['vox'][0])
|
255 |
+
return logits
|
256 |
+
|
257 |
+
|
258 |
+
class OSGDecoder(torch.nn.Module):
|
259 |
+
def __init__(self, n_features, options):
|
260 |
+
super().__init__()
|
261 |
+
self.hidden_dim = 64
|
262 |
+
|
263 |
+
self.net = torch.nn.Sequential(
|
264 |
+
FullyConnectedLayer(n_features, self.hidden_dim),
|
265 |
+
torch.nn.Softplus(),
|
266 |
+
FullyConnectedLayer(self.hidden_dim, 1 + options['decoder_output_dim'])
|
267 |
+
)
|
268 |
+
|
269 |
+
def forward(self, sampled_features, ray_directions=None):
|
270 |
+
# Aggregate features
|
271 |
+
sampled_features = sampled_features.mean(1)
|
272 |
+
x = sampled_features
|
273 |
+
|
274 |
+
N, M, C = x.shape
|
275 |
+
x = x.view(N*M, C)
|
276 |
+
|
277 |
+
x = self.net(x)
|
278 |
+
x = x.view(N, M, -1)
|
279 |
+
return x
|
280 |
+
rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
|
281 |
+
sigma = x[..., 0:1]
|
282 |
+
return {'rgb': rgb, 'sigma': sigma}
|
Anymate/models/skin.py
ADDED
@@ -0,0 +1,309 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from ThirdParty.michelangelo.models.modules.transformer_blocks import ResidualCrossAttentionBlock, Transformer
|
4 |
+
from ThirdParty.michelangelo.models.modules.embedder import components_from_spherical_harmonics, FourierEmbedder
|
5 |
+
from einops import repeat, rearrange
|
6 |
+
|
7 |
+
class AttendjointsDecoder_combine(nn.Module):
|
8 |
+
def __init__(self,
|
9 |
+
width = 768,
|
10 |
+
layers = 2,
|
11 |
+
device = 'cuda',
|
12 |
+
dtype = torch.float32,
|
13 |
+
heads = 12,
|
14 |
+
init_scale: float = 0.25,
|
15 |
+
flash = False,
|
16 |
+
use_checkpoint = False,
|
17 |
+
qkv_bias = False,
|
18 |
+
num_freqs: int = 8,
|
19 |
+
include_pi: bool = True,
|
20 |
+
separate = False,
|
21 |
+
use_mask = True,
|
22 |
+
use_bone = True,
|
23 |
+
inference= False):
|
24 |
+
|
25 |
+
super().__init__()
|
26 |
+
self.inference = inference
|
27 |
+
self.use_checkpoint = use_checkpoint
|
28 |
+
self.separate = separate
|
29 |
+
self.use_mask = use_mask
|
30 |
+
# self.num_latents = num_latents
|
31 |
+
|
32 |
+
# self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
|
33 |
+
|
34 |
+
self.normal_embedder = components_from_spherical_harmonics
|
35 |
+
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
36 |
+
self.bone_proj = None if not use_bone else nn.Linear(self.fourier_embedder.out_dim * 2, width, device=device, dtype=dtype)
|
37 |
+
self.use_bone = use_bone
|
38 |
+
|
39 |
+
if not self.separate:
|
40 |
+
self.co_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
|
41 |
+
self.normal_proj = nn.Linear(25, width, device=device, dtype=dtype)
|
42 |
+
else:
|
43 |
+
self.pc_proj = nn.Linear(self.fourier_embedder.out_dim + 25, width, device=device, dtype=dtype)
|
44 |
+
|
45 |
+
|
46 |
+
# self.proj_attn = nn.Linear(width, width, device=device, dtype=dtype)
|
47 |
+
|
48 |
+
self.cross_attn = nn.ModuleList([ResidualCrossAttentionBlock(
|
49 |
+
device=device,
|
50 |
+
dtype=dtype,
|
51 |
+
width=width,
|
52 |
+
heads=heads,
|
53 |
+
init_scale=init_scale,
|
54 |
+
qkv_bias=qkv_bias,
|
55 |
+
flash=flash,
|
56 |
+
) for _ in range(layers)])
|
57 |
+
|
58 |
+
self.cross_attn_joint = nn.ModuleList([ResidualCrossAttentionBlock(
|
59 |
+
device=device,
|
60 |
+
dtype=dtype,
|
61 |
+
width=width,
|
62 |
+
heads=heads,
|
63 |
+
init_scale=init_scale,
|
64 |
+
qkv_bias=qkv_bias,
|
65 |
+
flash=flash,
|
66 |
+
) for _ in range(layers)])
|
67 |
+
|
68 |
+
# self.joint_embed_proj = nn.ModuleList([nn.Linear(width, width, device=device, dtype=dtype) for _ in range(layers)])
|
69 |
+
|
70 |
+
|
71 |
+
self.q_proj = nn.Linear(width, width, device=device, dtype=dtype)
|
72 |
+
self.k_proj = nn.Linear(width, width, device=device, dtype=dtype)
|
73 |
+
self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
|
74 |
+
self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
|
75 |
+
|
76 |
+
# self.last_cross_attn = ResidualCrossAttentionBlock(
|
77 |
+
# device=device,
|
78 |
+
# dtype=dtype,
|
79 |
+
# width=width,
|
80 |
+
# heads=heads,
|
81 |
+
# init_scale=init_scale,
|
82 |
+
# qkv_bias=qkv_bias,
|
83 |
+
# flash=flash,
|
84 |
+
# )
|
85 |
+
# self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
|
86 |
+
# self.output_proj = nn.Linear(width, 1, device=device, dtype=dtype)
|
87 |
+
|
88 |
+
def forward(self, latents, data=None, device='cuda', downsample=None, dtype=torch.float32):
|
89 |
+
joints = data['bones'].to(device) if self.use_bone else data['joints'].to(device)
|
90 |
+
max_joints = max(data['bones_num']) if self.use_bone else max(data['joints_num'])
|
91 |
+
mask = data['bones_mask'].to(device) if self.use_bone else data['joints_mask']
|
92 |
+
|
93 |
+
pc = data['vertices'][..., 0:3].to(device) if self.inference else data['points_cloud'][..., 0:3].to(device)
|
94 |
+
feats = data['vertices'][..., 3:].to(device) if self.inference else data['points_cloud'][..., 3:].to(device)
|
95 |
+
|
96 |
+
if downsample and not self.inference:
|
97 |
+
# random sample
|
98 |
+
idx = torch.randperm(pc.shape[1])[:downsample].to(device)
|
99 |
+
pc = pc[:, idx]
|
100 |
+
feats = feats[:, idx]
|
101 |
+
|
102 |
+
# Embed the input data
|
103 |
+
co_embeds = self.fourier_embedder(pc)
|
104 |
+
if not self.separate:
|
105 |
+
co_embeds = self.co_proj(co_embeds)
|
106 |
+
|
107 |
+
if self.use_bone:
|
108 |
+
# joints_fourier = torch.cat((self.fourier_embedder(joints[:,:max_joints*2:2, :3]), self.fourier_embedder(joints[:,1:max_joints*2:2, :3])), dim=-1)
|
109 |
+
joints_fourier = torch.cat((self.fourier_embedder(joints[:,:max_joints,:3]), self.fourier_embedder(joints[:,:max_joints, 3:])), dim=-1)
|
110 |
+
else:
|
111 |
+
joints_fourier = self.fourier_embedder(joints[:,:max_joints, :3])
|
112 |
+
|
113 |
+
if not self.separate:
|
114 |
+
joints_embeds = self.co_proj(joints_fourier) if not self.use_bone else self.bone_proj(joints_fourier)
|
115 |
+
|
116 |
+
normal_embeds = self.normal_proj(self.normal_embedder(feats)) if not self.separate else self.normal_embedder(feats)
|
117 |
+
|
118 |
+
if not self.separate:
|
119 |
+
pc_embeds = co_embeds + normal_embeds
|
120 |
+
else:
|
121 |
+
joints_embeds = self.co_proj(joints_fourier.to(dtype)) if not self.use_bone else self.bone_proj(joints_fourier.to(dtype))
|
122 |
+
pc_embeds = self.pc_proj(torch.cat([co_embeds.to(dtype), normal_embeds.to(dtype)], dim=-1))
|
123 |
+
|
124 |
+
pc_num = pc_embeds.shape[-2]
|
125 |
+
joints_num = joints_embeds.shape[-2]
|
126 |
+
x = torch.cat([pc_embeds, joints_embeds], dim=-2)
|
127 |
+
for i, layer in enumerate(self.cross_attn):
|
128 |
+
|
129 |
+
x = layer(x, latents)
|
130 |
+
if self.use_mask:
|
131 |
+
x = self.cross_attn_joint[i](x, x[:, pc_num:], mask=mask.to(device))
|
132 |
+
else:
|
133 |
+
x = self.cross_attn_joint[i](x, x[:, pc_num:])
|
134 |
+
pc_embeds, joints_embeds = x.split([pc_num, joints_num], dim=1)
|
135 |
+
|
136 |
+
logits = torch.einsum('bnc,bmc->bnm', self.k_proj(self.ln_1(pc_embeds)), self.q_proj(self.ln_2(joints_embeds))) # (b, n, m)
|
137 |
+
|
138 |
+
if self.use_mask:
|
139 |
+
logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e8)
|
140 |
+
|
141 |
+
if downsample and not self.inference:
|
142 |
+
return logits, idx
|
143 |
+
|
144 |
+
return logits
|
145 |
+
|
146 |
+
class AttendjointsDecoder_multi(nn.Module):
|
147 |
+
def __init__(self,
|
148 |
+
# num_latents = 64,
|
149 |
+
# num_kv_latents = 257,
|
150 |
+
# out_channels = 3,
|
151 |
+
width = 768,
|
152 |
+
layers = 4,
|
153 |
+
device = 'cuda',
|
154 |
+
dtype = torch.float32,
|
155 |
+
heads = 12,
|
156 |
+
init_scale: float = 0.25,
|
157 |
+
flash = False,
|
158 |
+
use_checkpoint = False,
|
159 |
+
qkv_bias = False,
|
160 |
+
num_freqs: int = 8,
|
161 |
+
concat_num: int = 512,
|
162 |
+
include_pi: bool = True,
|
163 |
+
separate = False,
|
164 |
+
use_mask = True,
|
165 |
+
inference_with_repeat=False,
|
166 |
+
use_bone = True,
|
167 |
+
inference = False):
|
168 |
+
|
169 |
+
super().__init__()
|
170 |
+
|
171 |
+
self.use_checkpoint = use_checkpoint
|
172 |
+
self.use_mask = use_mask
|
173 |
+
self.inference_with_repeat = inference_with_repeat
|
174 |
+
self.inference = inference
|
175 |
+
|
176 |
+
self.self_attn = Transformer(
|
177 |
+
device=device,
|
178 |
+
dtype=dtype,
|
179 |
+
n_ctx=-1,
|
180 |
+
width=width,
|
181 |
+
layers=layers,
|
182 |
+
heads=heads,
|
183 |
+
init_scale=init_scale,
|
184 |
+
qkv_bias=qkv_bias,
|
185 |
+
flash=flash,
|
186 |
+
use_checkpoint=False,
|
187 |
+
|
188 |
+
)
|
189 |
+
self.concat_number = concat_num
|
190 |
+
self.separate = separate
|
191 |
+
self.normal_embedder = components_from_spherical_harmonics
|
192 |
+
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
193 |
+
self.bone_proj = None if not use_bone else nn.Linear(self.fourier_embedder.out_dim * 2, width, device=device, dtype=dtype)
|
194 |
+
self.use_bone = use_bone
|
195 |
+
if not self.separate:
|
196 |
+
self.co_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
|
197 |
+
self.normal_proj = nn.Linear(25, width, device=device, dtype=dtype)
|
198 |
+
else:
|
199 |
+
self.pc_proj = nn.Linear(self.fourier_embedder.out_dim + 25, width, device=device, dtype=dtype)
|
200 |
+
|
201 |
+
# self.proj_attn = nn.Linear(width, width, device=device, dtype=dtype)
|
202 |
+
|
203 |
+
# self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
|
204 |
+
self.output_proj_joints = nn.Linear(width, width, device=device, dtype=dtype)
|
205 |
+
self.output_proj_points = nn.Linear(width, width, device=device, dtype=dtype)
|
206 |
+
self.layer_norm = nn.LayerNorm(width)
|
207 |
+
|
208 |
+
# def inference(self, latents, data=None,device='cuda', dtype='float32', use_mask=False):
|
209 |
+
def inference_mode(self):
|
210 |
+
self.inference = True
|
211 |
+
|
212 |
+
def forward(self, latents, data=None,device='cuda', downsample=None, dtype='float32'):
|
213 |
+
joints = data['bones'].to(device) if self.use_bone else data['joints'].to(device)
|
214 |
+
max_joints = max(data['bones_num']) if self.use_bone else max(data['joints_num'])
|
215 |
+
|
216 |
+
pc = data['points_cloud'][..., 0:3].to(device)
|
217 |
+
feats = data['points_cloud'][..., 3:].to(device)
|
218 |
+
|
219 |
+
if downsample:
|
220 |
+
# random sample
|
221 |
+
idx = torch.randperm(pc.shape[1])[:downsample].to(device)
|
222 |
+
pc = pc[:, idx]
|
223 |
+
feats = feats[:, idx]
|
224 |
+
|
225 |
+
bs = pc.shape[1]//self.concat_number
|
226 |
+
|
227 |
+
# Embed the input data
|
228 |
+
if self.use_bone:
|
229 |
+
# joints_fourier = torch.cat((self.fourier_embedder(joints[:,:max_joints*2:2, :3]), self.fourier_embedder(joints[:,1:max_joints*2:2, :3])), dim=-1)
|
230 |
+
joints_fourier = torch.cat((self.fourier_embedder(joints[:,:max_joints,:3]), self.fourier_embedder(joints[:,:max_joints, 3:])), dim=-1)
|
231 |
+
else:
|
232 |
+
joints_fourier = self.fourier_embedder(joints[:,:max_joints, :3])
|
233 |
+
|
234 |
+
if self.separate:
|
235 |
+
joints_embeds = self.co_proj(joints_fourier.to(dtype)) if not self.use_bone else self.bone_proj(joints_fourier.to(dtype))
|
236 |
+
points_embeds = self.fourier_embedder(pc)
|
237 |
+
normal_embeds = self.normal_embedder(feats)
|
238 |
+
points = self.pc_proj(torch.cat([points_embeds, normal_embeds], dim=-1))
|
239 |
+
else:
|
240 |
+
joints_embeds = self.co_proj(joints_fourier) if not self.use_bone else self.bone_proj(joints_fourier)
|
241 |
+
co_embeds = self.fourier_embedder(pc)
|
242 |
+
co_embeds = self.co_proj(co_embeds)
|
243 |
+
# Embed the normals
|
244 |
+
normal_embeds = self.normal_embedder(feats)
|
245 |
+
normal_embeds = self.normal_proj(normal_embeds) # (b, n, c)
|
246 |
+
points = (co_embeds + normal_embeds)
|
247 |
+
|
248 |
+
repeated_latents = repeat(latents, "b m c -> b n m c", n=bs)
|
249 |
+
repeated_joints = repeat(joints_embeds, "b m c -> b n m c", n=bs)
|
250 |
+
points = points.reshape( latents.shape[0], bs, self.concat_number, -1)
|
251 |
+
|
252 |
+
# Concatenate embeddings
|
253 |
+
x = torch.cat([repeated_joints, points, repeated_latents], dim=-2) # (b, bs, concat_number+latent_num+joints_num, c)
|
254 |
+
|
255 |
+
# Pass through self-attention
|
256 |
+
if self.use_mask:
|
257 |
+
mask = data['bones_mask'].to(device)
|
258 |
+
append_size = x.shape[2]-mask.shape[1] # the zero needs to append after mask
|
259 |
+
batch_size = mask.shape[0]
|
260 |
+
mask_extend = torch.ones((batch_size,append_size)).to(device)
|
261 |
+
mask = torch.cat([mask,mask_extend],dim=-1).repeat(bs,1).to(device)
|
262 |
+
x = rearrange(x, "b n m c -> (b n) m c")
|
263 |
+
x = self.self_attn(x,mask)
|
264 |
+
else:
|
265 |
+
x = rearrange(x, "b n m c -> (b n) m c")
|
266 |
+
x = self.self_attn(x)
|
267 |
+
joints, points, _ = x.split([joints_embeds.shape[1],self.concat_number, latents.shape[1]], dim=1)
|
268 |
+
joints = self.output_proj_joints(self.layer_norm(joints))
|
269 |
+
points = self.output_proj_points(self.layer_norm(points))
|
270 |
+
|
271 |
+
logits = torch.einsum('bik,bjk->bij', points, joints)
|
272 |
+
logits = rearrange(logits, '(b n) m c -> b (n m) c', b=pc.shape[0],n=bs) # (b, n, c)
|
273 |
+
|
274 |
+
if self.use_mask:
|
275 |
+
mask = data['bones_mask'].to(device)
|
276 |
+
logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e8)
|
277 |
+
|
278 |
+
if self.inference:
|
279 |
+
vertices = data['vertice']
|
280 |
+
points_cloud = data['points_cloud'][0,..., 0:3].to(device)
|
281 |
+
vertices_exp = vertices[0,...,:3] # (batch_size, num_vertices, 1, 3)
|
282 |
+
logits = compute_nearest_points(vertices_exp, points_cloud, logits[0], device)
|
283 |
+
|
284 |
+
if downsample:
|
285 |
+
return logits, idx
|
286 |
+
|
287 |
+
return logits
|
288 |
+
|
289 |
+
def compute_nearest_points(vertices, points, logits, device, batch_size=1024):
|
290 |
+
# vertices: [N, 3]
|
291 |
+
# points: [M, 3]
|
292 |
+
# logits: [M, K] (K is the number of skinning weights)
|
293 |
+
|
294 |
+
num_vertices = vertices.shape[0]
|
295 |
+
# Initialize the output tensor for skinning weights
|
296 |
+
skin_predict = torch.zeros((num_vertices, logits.shape[1]), device=device)
|
297 |
+
|
298 |
+
# Split vertices into batches
|
299 |
+
for i in range(0, num_vertices, batch_size):
|
300 |
+
|
301 |
+
batch_vertices = vertices[i:i+batch_size] # [batch_size, 3]
|
302 |
+
vertices_exp = batch_vertices.unsqueeze(1) # [batch_size, 1, 3]
|
303 |
+
points_exp = points.unsqueeze(0) # [1, num_points, 3]
|
304 |
+
distances = torch.sum((vertices_exp - points_exp) ** 2, dim=-1) # [batch_size, num_points]
|
305 |
+
nearest_idx = torch.argmin(distances, dim=-1) # [batch_size]
|
306 |
+
skin_predict_batch = logits[nearest_idx] # [batch_size, K]
|
307 |
+
skin_predict[i:i+batch_size] = skin_predict_batch
|
308 |
+
|
309 |
+
return skin_predict
|
Anymate/tmp/.gitkeep
ADDED
File without changes
|
Anymate/utils/dataset_utils.py
ADDED
@@ -0,0 +1,129 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
import trimesh
|
4 |
+
from ThirdParty.Rignet_utils import binvox_rw
|
5 |
+
|
6 |
+
|
7 |
+
def sparse_to_index(sparse_matrix):
|
8 |
+
index = []
|
9 |
+
weight = []
|
10 |
+
for j in range(len(sparse_matrix)):
|
11 |
+
if sparse_matrix[j] > 0:
|
12 |
+
index.append(j)
|
13 |
+
weight.append(sparse_matrix[j])
|
14 |
+
|
15 |
+
return index, weight
|
16 |
+
|
17 |
+
def index_to_sparse(index, weight, shape):
|
18 |
+
sparse_matrix = np.zeros([shape[0], shape[1], shape[2]+1])
|
19 |
+
|
20 |
+
row_indices, col_indices = np.meshgrid(np.arange(sparse_matrix.shape[0]), np.arange(sparse_matrix.shape[1]), indexing='ij')
|
21 |
+
|
22 |
+
row_indices = np.expand_dims(row_indices, axis=-1)
|
23 |
+
col_indices = np.expand_dims(col_indices, axis=-1)
|
24 |
+
|
25 |
+
sparse_matrix[row_indices, col_indices, index] = weight
|
26 |
+
|
27 |
+
|
28 |
+
return torch.from_numpy(sparse_matrix[:, :, :-1])
|
29 |
+
|
30 |
+
def index_to_sparse_con(index, shape):
|
31 |
+
|
32 |
+
sparse_matrix = np.zeros([shape[0], shape[1], shape[2]+1],dtype=np.int8)
|
33 |
+
row_indices, col_indices = np.meshgrid(np.arange(sparse_matrix.shape[0]), np.arange(sparse_matrix.shape[1]), indexing='ij')
|
34 |
+
|
35 |
+
row_indices = np.expand_dims(row_indices, axis=-1)
|
36 |
+
col_indices = np.expand_dims(col_indices, axis=-1)
|
37 |
+
|
38 |
+
sparse_matrix[row_indices, col_indices, index] = 1
|
39 |
+
|
40 |
+
|
41 |
+
return torch.from_numpy(sparse_matrix[:, :, :-1])
|
42 |
+
|
43 |
+
def create_mask(n, max_len=64):
|
44 |
+
mask = torch.zeros(max_len, dtype=torch.bool)
|
45 |
+
mask[:n] = 1
|
46 |
+
return mask
|
47 |
+
|
48 |
+
def reduce(vox):
|
49 |
+
new_data = np.zeros((vox.dims[0] // 2, vox.dims[1] // 2, vox.dims[2] // 2)).astype(bool)
|
50 |
+
new_data = np.logical_or(new_data, vox.data[::2, ::2, ::2])
|
51 |
+
new_data = np.logical_or(new_data, vox.data[1::2, ::2, ::2])
|
52 |
+
new_data = np.logical_or(new_data, vox.data[::2, 1::2, ::2])
|
53 |
+
new_data = np.logical_or(new_data, vox.data[::2, ::2, 1::2])
|
54 |
+
new_data = np.logical_or(new_data, vox.data[1::2, 1::2, ::2])
|
55 |
+
new_data = np.logical_or(new_data, vox.data[1::2, ::2, 1::2])
|
56 |
+
new_data = np.logical_or(new_data, vox.data[::2, 1::2, 1::2])
|
57 |
+
new_data = np.logical_or(new_data, vox.data[1::2, 1::2, 1::2])
|
58 |
+
# dilate the new voxel
|
59 |
+
new_data[:-1, :, :] = np.logical_or(new_data[:-1, :, :], new_data[1:, :, :])
|
60 |
+
new_data[:, :-1, :] = np.logical_or(new_data[:, :-1, :], new_data[:, 1:, :])
|
61 |
+
new_data[:, :, :-1] = np.logical_or(new_data[:, :, :-1], new_data[:, :, 1:])
|
62 |
+
return binvox_rw.Voxels(new_data, new_data.shape, vox.translate, vox.scale, vox.axis_order)
|
63 |
+
|
64 |
+
def align(vox, y_max):
|
65 |
+
new_data = np.zeros(vox.dims).astype(bool)
|
66 |
+
ind = np.argwhere(vox.data)
|
67 |
+
ind = ind + (np.array(vox.translate) - np.array([-0.5, -0.5 * (1 - y_max), -0.5])) * vox.dims[0]
|
68 |
+
# round to the nearest integer
|
69 |
+
# ind = np.round(ind).astype(int)
|
70 |
+
ind = np.ceil(ind).astype(int)
|
71 |
+
# clip to the valid range
|
72 |
+
ind = np.clip(ind, 0, vox.dims[0] - 1)
|
73 |
+
# new_data[ind[:, 0], ind[:, 1], ind[:, 2]] = True
|
74 |
+
return ind
|
75 |
+
|
76 |
+
def get_skin_direction(joint_idx, data, parent_index, joints_matrix):
|
77 |
+
# Get points influenced by this joint (weight > 0)
|
78 |
+
weights = index_to_sparse(data['skins_index'].unsqueeze(0), data['skins_weight'].unsqueeze(0), [1, 8192, data['bones_num']])[0][:,joint_idx]
|
79 |
+
mask = weights > 0
|
80 |
+
|
81 |
+
if not torch.any(mask):
|
82 |
+
# If no points are influenced, return the opposite direction of its parent
|
83 |
+
parent_idx = parent_index[joint_idx].item()
|
84 |
+
if parent_idx == joint_idx:
|
85 |
+
return torch.tensor([0, 0, 0.001])
|
86 |
+
parent_pos = joints_matrix[parent_idx, :3]
|
87 |
+
joint_pos = joints_matrix[joint_idx, :3]
|
88 |
+
direction = joint_pos - parent_pos
|
89 |
+
norm = torch.norm(direction)
|
90 |
+
if norm < 1e-8: # Add check for zero norm
|
91 |
+
return torch.tensor([0, 0, 0.001])
|
92 |
+
normalized_direction = direction / norm
|
93 |
+
return normalized_direction * 0.01
|
94 |
+
|
95 |
+
# Get joint position
|
96 |
+
joint_pos = joints_matrix[joint_idx, :3]
|
97 |
+
|
98 |
+
# Get weighted average direction from joint to influenced points
|
99 |
+
points = data['pc'][mask][:,:3]
|
100 |
+
point_weights = weights[mask]
|
101 |
+
|
102 |
+
# Calculate directions from joint to each point
|
103 |
+
directions = points - joint_pos
|
104 |
+
|
105 |
+
# Calculate weighted average direction
|
106 |
+
avg_direction = torch.sum(directions * point_weights.unsqueeze(1), dim=0) / torch.sum(point_weights)
|
107 |
+
if torch.norm(avg_direction) < 1e-5:
|
108 |
+
return torch.tensor([0, 0, 0.001])
|
109 |
+
return avg_direction * 1.25
|
110 |
+
|
111 |
+
def obj2mesh(obj_path):
|
112 |
+
# open the obj as txt
|
113 |
+
vertices = []
|
114 |
+
faces = []
|
115 |
+
with open(obj_path, 'r') as f:
|
116 |
+
obj = f.readlines()
|
117 |
+
for line in obj:
|
118 |
+
if line.startswith('v '):
|
119 |
+
vertices.append(list(map(float, line.split()[1:])))
|
120 |
+
elif line.startswith('f '):
|
121 |
+
faces.append(list(map(int, [i.split('/')[0] for i in line.split()[1:]])))
|
122 |
+
vertices = np.array(vertices)
|
123 |
+
faces = np.array(faces) - 1
|
124 |
+
# print(vertices.shape, faces.shape)
|
125 |
+
|
126 |
+
# create trimesh mesh with given vertices and faces
|
127 |
+
mesh = trimesh.Trimesh(vertices, faces, process=False)
|
128 |
+
# print(mesh.vertices.shape, mesh.faces.shape)
|
129 |
+
return mesh
|
Anymate/utils/diffusion_encoder.py
ADDED
@@ -0,0 +1,258 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from typing import Optional
|
4 |
+
from einops import repeat
|
5 |
+
import math
|
6 |
+
from ThirdParty.michelangelo.models.modules.transformer_blocks import ResidualCrossAttentionBlock,Transformer, checkpoint
|
7 |
+
from torch.nn import Sequential, Dropout, Linear, ReLU, Parameter, BatchNorm1d
|
8 |
+
from typing import List, Optional, Tuple, Union
|
9 |
+
|
10 |
+
class ShapeAsLatentModule(nn.Module):
|
11 |
+
latent_shape: Tuple[int, int]
|
12 |
+
|
13 |
+
def __init__(self, *args, **kwargs):
|
14 |
+
super().__init__()
|
15 |
+
|
16 |
+
def encode(self, *args, **kwargs):
|
17 |
+
raise NotImplementedError
|
18 |
+
|
19 |
+
def decode(self, *args, **kwargs):
|
20 |
+
raise NotImplementedError
|
21 |
+
|
22 |
+
def query_geometry(self, *args, **kwargs):
|
23 |
+
raise NotImplementedError
|
24 |
+
|
25 |
+
class FourierEmbedder(nn.Module):
|
26 |
+
|
27 |
+
def __init__(self,
|
28 |
+
num_freqs: int = 6,
|
29 |
+
logspace: bool = True,
|
30 |
+
input_dim: int = 3,
|
31 |
+
include_input: bool = True,
|
32 |
+
include_pi: bool = True) -> None:
|
33 |
+
|
34 |
+
"""The initialization"""
|
35 |
+
|
36 |
+
super().__init__()
|
37 |
+
|
38 |
+
if logspace:
|
39 |
+
frequencies = 2.0 ** torch.arange(
|
40 |
+
num_freqs,
|
41 |
+
dtype=torch.float32
|
42 |
+
)
|
43 |
+
else:
|
44 |
+
frequencies = torch.linspace(
|
45 |
+
1.0,
|
46 |
+
2.0 ** (num_freqs - 1),
|
47 |
+
num_freqs,
|
48 |
+
dtype=torch.float32
|
49 |
+
)
|
50 |
+
|
51 |
+
if include_pi:
|
52 |
+
frequencies *= torch.pi
|
53 |
+
|
54 |
+
self.register_buffer("frequencies", frequencies, persistent=False)
|
55 |
+
self.include_input = include_input
|
56 |
+
self.num_freqs = num_freqs
|
57 |
+
|
58 |
+
self.out_dim = self.get_dims(input_dim)
|
59 |
+
|
60 |
+
def get_dims(self, input_dim):
|
61 |
+
temp = 1 if self.include_input or self.num_freqs == 0 else 0
|
62 |
+
out_dim = input_dim * (self.num_freqs * 2 + temp)
|
63 |
+
|
64 |
+
return out_dim
|
65 |
+
|
66 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
67 |
+
|
68 |
+
if self.num_freqs > 0:
|
69 |
+
self.frequencies = self.frequencies.to(x.device)
|
70 |
+
embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
|
71 |
+
|
72 |
+
if self.include_input:
|
73 |
+
return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
|
74 |
+
else:
|
75 |
+
return torch.cat((embed.sin(), embed.cos()), dim=-1)
|
76 |
+
else:
|
77 |
+
return x
|
78 |
+
|
79 |
+
def MLP(channels, batch_norm=True):
|
80 |
+
if batch_norm:
|
81 |
+
return Sequential(*[Sequential(Linear(channels[i - 1], channels[i]), ReLU(), BatchNorm1d(channels[i], momentum=0.1))
|
82 |
+
for i in range(1, len(channels))])
|
83 |
+
else:
|
84 |
+
return Sequential(*[Sequential(Linear(channels[i - 1], channels[i]), ReLU()) for i in range(1, len(channels))])
|
85 |
+
|
86 |
+
class CrossAttentionEncoder(nn.Module):
|
87 |
+
|
88 |
+
def __init__(self, *,
|
89 |
+
device: Optional[torch.device],
|
90 |
+
dtype: Optional[torch.dtype],
|
91 |
+
num_latents: int,
|
92 |
+
fourier_embedder: FourierEmbedder,
|
93 |
+
point_feats: int,
|
94 |
+
width: int,
|
95 |
+
heads: int,
|
96 |
+
layers: int,
|
97 |
+
init_scale: float = 0.25,
|
98 |
+
qkv_bias: bool = True,
|
99 |
+
flash: bool = False,
|
100 |
+
use_ln_post: bool = False,
|
101 |
+
use_checkpoint: bool = False):
|
102 |
+
|
103 |
+
super().__init__()
|
104 |
+
|
105 |
+
self.use_checkpoint = use_checkpoint
|
106 |
+
self.num_latents = num_latents
|
107 |
+
self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
|
108 |
+
|
109 |
+
self.fourier_embedder = fourier_embedder
|
110 |
+
self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
|
111 |
+
self.cross_attn = ResidualCrossAttentionBlock(
|
112 |
+
device=device,
|
113 |
+
dtype=dtype,
|
114 |
+
width=width,
|
115 |
+
heads=heads,
|
116 |
+
init_scale=init_scale,
|
117 |
+
qkv_bias=qkv_bias,
|
118 |
+
flash=flash,
|
119 |
+
)
|
120 |
+
|
121 |
+
self.self_attn = Transformer(
|
122 |
+
device=device,
|
123 |
+
dtype=dtype,
|
124 |
+
n_ctx=num_latents,
|
125 |
+
width=width,
|
126 |
+
layers=layers,
|
127 |
+
heads=heads,
|
128 |
+
init_scale=init_scale,
|
129 |
+
qkv_bias=qkv_bias,
|
130 |
+
flash=flash,
|
131 |
+
use_checkpoint=False
|
132 |
+
)
|
133 |
+
|
134 |
+
if use_ln_post:
|
135 |
+
self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
|
136 |
+
else:
|
137 |
+
self.ln_post = None
|
138 |
+
|
139 |
+
def _forward(self, pc, feats):
|
140 |
+
"""
|
141 |
+
|
142 |
+
Args:
|
143 |
+
pc (torch.FloatTensor): [B, N, 3]
|
144 |
+
feats (torch.FloatTensor or None): [B, N, C]
|
145 |
+
|
146 |
+
Returns:
|
147 |
+
|
148 |
+
"""
|
149 |
+
|
150 |
+
bs = pc.shape[0]
|
151 |
+
|
152 |
+
data = self.fourier_embedder(pc)
|
153 |
+
if feats is not None:
|
154 |
+
data = torch.cat([data, feats], dim=-1)
|
155 |
+
data = self.input_proj(data)
|
156 |
+
|
157 |
+
query = repeat(self.query, "m c -> b m c", b=bs)
|
158 |
+
latents = self.cross_attn(query, data)
|
159 |
+
latents = self.self_attn(latents)
|
160 |
+
|
161 |
+
if self.ln_post is not None:
|
162 |
+
latents = self.ln_post(latents)
|
163 |
+
|
164 |
+
return latents, pc
|
165 |
+
|
166 |
+
def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
|
167 |
+
"""
|
168 |
+
|
169 |
+
Args:
|
170 |
+
pc (torch.FloatTensor): [B, N, 3]
|
171 |
+
feats (torch.FloatTensor or None): [B, N, C]
|
172 |
+
|
173 |
+
Returns:
|
174 |
+
dict
|
175 |
+
"""
|
176 |
+
|
177 |
+
return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)
|
178 |
+
|
179 |
+
|
180 |
+
|
181 |
+
class TransformerEncoder(ShapeAsLatentModule):
|
182 |
+
def __init__(self, *,
|
183 |
+
device: Optional[torch.device]='cuda',
|
184 |
+
dtype: Optional[torch.dtype],
|
185 |
+
num_latents: int = 16,
|
186 |
+
point_feats: int = 3,
|
187 |
+
embed_dim: int = 64,
|
188 |
+
num_freqs: int = 8,
|
189 |
+
include_pi: bool = True,
|
190 |
+
width: int = 768,
|
191 |
+
heads: int = 12,
|
192 |
+
num_encoder_layers: int = 8,
|
193 |
+
init_scale: float = 0.25,
|
194 |
+
qkv_bias: bool = True,
|
195 |
+
flash: bool = False,
|
196 |
+
use_ln_post: bool = False,
|
197 |
+
use_checkpoint: bool = False,
|
198 |
+
out_channels: int = 4):
|
199 |
+
|
200 |
+
super().__init__()
|
201 |
+
|
202 |
+
self.use_checkpoint = use_checkpoint
|
203 |
+
|
204 |
+
self.num_latents = num_latents
|
205 |
+
self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
|
206 |
+
|
207 |
+
init_scale = init_scale * math.sqrt(1.0 / width)
|
208 |
+
self.encoder = CrossAttentionEncoder(
|
209 |
+
device=device,
|
210 |
+
dtype=dtype,
|
211 |
+
fourier_embedder=self.fourier_embedder,
|
212 |
+
num_latents=num_latents,
|
213 |
+
point_feats=point_feats,
|
214 |
+
width=width,
|
215 |
+
heads=heads,
|
216 |
+
layers=num_encoder_layers,
|
217 |
+
init_scale=init_scale,
|
218 |
+
qkv_bias=qkv_bias,
|
219 |
+
flash=flash,
|
220 |
+
use_ln_post=use_ln_post,
|
221 |
+
use_checkpoint=use_checkpoint
|
222 |
+
)
|
223 |
+
self.width = width
|
224 |
+
self.out_channels = out_channels
|
225 |
+
self.device = device
|
226 |
+
|
227 |
+
self.embed_dim = embed_dim
|
228 |
+
|
229 |
+
def encode(self,data):
|
230 |
+
input_points = data['points_cloud'].to(self.device)
|
231 |
+
bs = input_points.shape[0]
|
232 |
+
pc, feats = input_points[...,:3], input_points[..., 3:]
|
233 |
+
latents, _ = self.encoder(pc, feats)
|
234 |
+
# print_time('after encoder')
|
235 |
+
latents = latents.reshape(bs,-1, self.width)
|
236 |
+
return latents
|
237 |
+
def encode_pc(self,points_cloud):
|
238 |
+
bs = points_cloud.shape[0]
|
239 |
+
input_points = points_cloud.to(self.device)
|
240 |
+
pc, feats = input_points[...,:3], input_points[..., 3:]
|
241 |
+
latents, _ = self.encoder(pc, feats)
|
242 |
+
|
243 |
+
latents = latents.reshape(bs,-1, self.width)
|
244 |
+
return latents
|
245 |
+
def forward(self, data):
|
246 |
+
|
247 |
+
# input_points = torch.from_numpy(np.array(data.points_cloud)).cuda()
|
248 |
+
input_points = data['points_cloud'].to(self.device)
|
249 |
+
pc, feats = input_points[...,:3], input_points[..., 3:]
|
250 |
+
latents, _ = self.encoder(pc, feats)
|
251 |
+
|
252 |
+
latents = latents.reshape(-1, self.width)
|
253 |
+
latents =latents.reshape(-1, self.num_latents, self.out_channels)
|
254 |
+
latents[..., :3] = torch.tanh(latents[..., :3])
|
255 |
+
latents[..., 3:] = torch.sigmoid(latents[..., 3:])
|
256 |
+
|
257 |
+
|
258 |
+
return latents
|
Anymate/utils/diffusion_utils.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
|
2 |
+
import numpy as np
|
3 |
+
import matplotlib.pyplot as plt
|
4 |
+
from mpl_toolkits.mplot3d import Axes3D
|
5 |
+
from torchvision.utils import make_grid
|
6 |
+
import torch
|
7 |
+
from typing import List, Optional, Tuple, Union
|
8 |
+
import torch.nn as nn
|
9 |
+
import math
|
10 |
+
from timm.models.vision_transformer import Mlp, DropPath
|
11 |
+
|
12 |
+
def my_collate_diff(batch,return_joints_num=128,random=False):
|
13 |
+
data = {}
|
14 |
+
for key in batch[0]:
|
15 |
+
if key=='vox' or key=='name' or key=='joints_num' or key=='skins_index' or key=='skins_weight' or key=='parent_index' or key=='conns' or key=='joints' or key=='bones' or key=='mesh_skins_index' or key=='mesh_skins_weight' or key=='mesh_pc' or key=='mesh_face':
|
16 |
+
data[key] = [sample[key] for sample in batch]
|
17 |
+
elif key=='pc':
|
18 |
+
data['points_cloud'] = torch.stack([sample['pc'] for sample in batch])
|
19 |
+
elif key=='skins':
|
20 |
+
continue
|
21 |
+
elif key=='bones_num':
|
22 |
+
data[key] = torch.tensor([sample['bones_num'] for sample in batch])
|
23 |
+
else:
|
24 |
+
data[key] = torch.stack([sample[key] for sample in batch])
|
25 |
+
|
26 |
+
if 'joints' in batch[0]:
|
27 |
+
padded_joints_matrix = torch.ones(len(data['name']), return_joints_num, 3) * (-3)
|
28 |
+
joints_matrix = torch.ones(len(data['name']), 96, 3) * (-3)
|
29 |
+
for i in range(len(data['name'])):
|
30 |
+
joints_matrix[i, :data['joints_num'][i], :] = data['joints'][i]
|
31 |
+
if not random:
|
32 |
+
for i in range(len(data['name'])):
|
33 |
+
padded_joints_matrix[i] = data['joints'][i].repeat(return_joints_num//data['joints_num'][i]+1,1)[:return_joints_num,:]
|
34 |
+
else:
|
35 |
+
for i in range(len(data['name'])):
|
36 |
+
padded_joints_matrix[i] = data['joints'][i][torch.randint(0, data['joints_num'][i], (return_joints_num,))]
|
37 |
+
data['joints_repeat'] = padded_joints_matrix
|
38 |
+
data['joints'] = joints_matrix
|
39 |
+
|
40 |
+
return data
|
41 |
+
|
42 |
+
def randn_tensor(
|
43 |
+
shape: Union[Tuple, List],
|
44 |
+
generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
|
45 |
+
device: Optional["torch.device"] = None,
|
46 |
+
dtype: Optional["torch.dtype"] = None,
|
47 |
+
layout: Optional["torch.layout"] = None,
|
48 |
+
):
|
49 |
+
"""A helper function to create random tensors on the desired `device` with the desired `dtype`. When
|
50 |
+
passing a list of generators, you can seed each batch size individually. If CPU generators are passed, the tensor
|
51 |
+
is always created on the CPU.
|
52 |
+
"""
|
53 |
+
# device on which tensor is created defaults to device
|
54 |
+
rand_device = device
|
55 |
+
batch_size = shape[0]
|
56 |
+
|
57 |
+
layout = layout or torch.strided
|
58 |
+
device = device or torch.device("cpu")
|
59 |
+
|
60 |
+
if generator is not None:
|
61 |
+
gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
|
62 |
+
if gen_device_type != device.type and gen_device_type == "cpu":
|
63 |
+
rand_device = "cpu"
|
64 |
+
|
65 |
+
elif gen_device_type != device.type and gen_device_type == "cuda":
|
66 |
+
raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
|
67 |
+
|
68 |
+
# make sure generator list of length 1 is treated like a non-list
|
69 |
+
if isinstance(generator, list) and len(generator) == 1:
|
70 |
+
generator = generator[0]
|
71 |
+
|
72 |
+
if isinstance(generator, list):
|
73 |
+
shape = (1,) + shape[1:]
|
74 |
+
latents = [
|
75 |
+
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
|
76 |
+
for i in range(batch_size)
|
77 |
+
]
|
78 |
+
latents = torch.cat(latents, dim=0).to(device)
|
79 |
+
else:
|
80 |
+
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
|
81 |
+
|
82 |
+
return latents
|
83 |
+
|
84 |
+
def timestep_embedding(timesteps, dim, max_period=10000):
|
85 |
+
"""
|
86 |
+
Create sinusoidal timestep embeddings.
|
87 |
+
:param timesteps: a 1-D Tensor of N indices, one per batch element.
|
88 |
+
These may be fractional.
|
89 |
+
:param dim: the dimension of the output.
|
90 |
+
:param max_period: controls the minimum frequency of the embeddings.
|
91 |
+
:return: an [N x dim] Tensor of positional embeddings.
|
92 |
+
"""
|
93 |
+
half = dim // 2
|
94 |
+
freqs = torch.exp(
|
95 |
+
-math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
|
96 |
+
).to(device=timesteps.device)
|
97 |
+
args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
|
98 |
+
embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
|
99 |
+
if dim % 2:
|
100 |
+
embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
|
101 |
+
return embedding
|
102 |
+
|
103 |
+
class CrossAttention(nn.Module):
|
104 |
+
def __init__(
|
105 |
+
self,
|
106 |
+
dim,
|
107 |
+
kv_dim=None,
|
108 |
+
num_heads=16,
|
109 |
+
qkv_bias=False,
|
110 |
+
attn_drop=0.,
|
111 |
+
proj_drop=0.,
|
112 |
+
):
|
113 |
+
super().__init__()
|
114 |
+
self.num_heads = num_heads
|
115 |
+
head_dim = dim // num_heads
|
116 |
+
self.scale = head_dim ** -0.5
|
117 |
+
|
118 |
+
kv_dim = dim if not kv_dim else kv_dim
|
119 |
+
self.wq = nn.Linear(dim, dim, bias=qkv_bias)
|
120 |
+
self.wk = nn.Linear(kv_dim, dim, bias=qkv_bias)
|
121 |
+
self.wv = nn.Linear(kv_dim, dim, bias=qkv_bias)
|
122 |
+
self.attn_drop_rate = attn_drop
|
123 |
+
self.attn_drop = nn.Dropout(self.attn_drop_rate)
|
124 |
+
self.proj = nn.Linear(dim, dim)
|
125 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
126 |
+
|
127 |
+
def forward(self, x_q, x_kv):
|
128 |
+
B, N_q, C = x_q.shape
|
129 |
+
B, N_kv, _ = x_kv.shape
|
130 |
+
# [B, N_q, C] -> [B, N_q, H, C/H] -> [B, H, N_q, C/H]
|
131 |
+
q = self.wq(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
132 |
+
# [B, N_kv, C] -> [B, N_kv, H, C/H] -> [B, H, N_kv, C/H]
|
133 |
+
k = self.wk(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
134 |
+
# [B, N_kv, C] -> [B, N_kv, H, C/H] -> [B, H, N_kv, C/H]
|
135 |
+
v = self.wv(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
|
136 |
+
|
137 |
+
# [B, H, N_q, C/H] @ [B, H, C/H, N_kv] -> [B, H, N_q, N_kv]
|
138 |
+
attn = (q @ k.transpose(-2, -1)) * self.scale
|
139 |
+
attn = attn.softmax(dim=-1)
|
140 |
+
attn = self.attn_drop(attn)
|
141 |
+
|
142 |
+
# [B, H, N_q, N_kv] @ [B, H, N_kv, C/H] -> [B, H, N_q, C/H]
|
143 |
+
x = attn @ v
|
144 |
+
|
145 |
+
# [B, H, N_q, C/H] -> [B, N_q, C]
|
146 |
+
x = x.transpose(1, 2).reshape(B, N_q, C)
|
147 |
+
x = self.proj(x)
|
148 |
+
x = self.proj_drop(x)
|
149 |
+
return x
|
150 |
+
|
151 |
+
class Compute_Block(nn.Module):
|
152 |
+
|
153 |
+
def __init__(self, z_dim, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
154 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
155 |
+
super().__init__()
|
156 |
+
self.norm_z1 = norm_layer(z_dim)
|
157 |
+
self.attn = CrossAttention(
|
158 |
+
z_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
159 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
160 |
+
self.norm_z2 = norm_layer(z_dim)
|
161 |
+
mlp_hidden_dim = int(z_dim * mlp_ratio)
|
162 |
+
self.mlp = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
163 |
+
|
164 |
+
def forward(self, z):
|
165 |
+
zn = self.norm_z1(z)
|
166 |
+
z = z + self.drop_path(self.attn(zn, zn))
|
167 |
+
z = z + self.drop_path(self.mlp(self.norm_z2(z)))
|
168 |
+
return z
|
169 |
+
|
170 |
+
class Read_Block(nn.Module):
|
171 |
+
|
172 |
+
def __init__(self, z_dim, x_dim, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
173 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
174 |
+
super().__init__()
|
175 |
+
self.norm_x = norm_layer(x_dim)
|
176 |
+
self.norm_z1 = norm_layer(z_dim)
|
177 |
+
self.attn = CrossAttention(
|
178 |
+
z_dim, x_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
179 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
180 |
+
self.norm_z2 = norm_layer(z_dim)
|
181 |
+
mlp_hidden_dim = int(z_dim * mlp_ratio)
|
182 |
+
self.mlp = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
183 |
+
|
184 |
+
def forward(self, z, x):
|
185 |
+
z = z + self.drop_path(self.attn(self.norm_z1(z), self.norm_x(x)))
|
186 |
+
z = z + self.drop_path(self.mlp(self.norm_z2(z)))
|
187 |
+
return z
|
188 |
+
|
189 |
+
class Write_Block(nn.Module):
|
190 |
+
|
191 |
+
def __init__(self, z_dim, x_dim, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
192 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
193 |
+
super().__init__()
|
194 |
+
self.norm_z = norm_layer(z_dim)
|
195 |
+
self.norm_x1 = norm_layer(x_dim)
|
196 |
+
self.attn = CrossAttention(
|
197 |
+
x_dim, z_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
|
198 |
+
self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
|
199 |
+
self.norm_x2 = norm_layer(x_dim)
|
200 |
+
mlp_hidden_dim = int(x_dim * mlp_ratio)
|
201 |
+
self.mlp = Mlp(in_features=x_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
|
202 |
+
|
203 |
+
def forward(self, z, x):
|
204 |
+
x = x + self.drop_path(self.attn(self.norm_x1(x), self.norm_z(z)))
|
205 |
+
x = x + self.drop_path(self.mlp(self.norm_x2(x)))
|
206 |
+
return x
|
207 |
+
|
208 |
+
class RCW_Block(nn.Module):
|
209 |
+
|
210 |
+
def __init__(self, z_dim, x_dim, num_compute_layers=4, num_heads=16,
|
211 |
+
mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
|
212 |
+
drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
|
213 |
+
super().__init__()
|
214 |
+
self.read = Read_Block(z_dim, x_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop,
|
215 |
+
attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer)
|
216 |
+
self.write = Write_Block(z_dim, x_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop,
|
217 |
+
attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer)
|
218 |
+
self.compute = nn.ModuleList([
|
219 |
+
Compute_Block(z_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop,
|
220 |
+
attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer)
|
221 |
+
for _ in range(num_compute_layers)
|
222 |
+
])
|
223 |
+
|
224 |
+
def forward(self, z, x):
|
225 |
+
z = self.read(z, x)
|
226 |
+
for layer in self.compute:
|
227 |
+
z = layer(z)
|
228 |
+
x = self.write(z, x)
|
229 |
+
return z, x
|
230 |
+
|
231 |
+
def pairwise_distances(x, y):
|
232 |
+
#Input: x is a Nxd matrix
|
233 |
+
# y is an optional Mxd matirx
|
234 |
+
#Output: dist is a NxM matrix where dist[i,j] is the square norm between x[i,:] and y[j,:]
|
235 |
+
# if y is not given then use 'y=x'.
|
236 |
+
#i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
|
237 |
+
x_norm = (x ** 2).sum(1).view(-1, 1)
|
238 |
+
y_t = torch.transpose(y, 0, 1)
|
239 |
+
y_norm = (y ** 2).sum(1).view(1, -1)
|
240 |
+
dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
|
241 |
+
return torch.clamp(dist, 0.0, np.inf)
|
242 |
+
|
243 |
+
def meanshift_cluster(pts_in, bandwidth, weights=None, max_iter=20):
|
244 |
+
"""
|
245 |
+
Meanshift clustering
|
246 |
+
:param pts_in: input points
|
247 |
+
:param bandwidth: bandwidth
|
248 |
+
:param weights: weights per pts indicting its importance in the clustering
|
249 |
+
:return: points after clustering
|
250 |
+
"""
|
251 |
+
diff = 1e10
|
252 |
+
num_iter = 1
|
253 |
+
while diff > 1e-3 and num_iter < max_iter:
|
254 |
+
Y = np.sum(((pts_in[np.newaxis, ...] - pts_in[:, np.newaxis, :]) ** 2), axis=2)
|
255 |
+
K = np.maximum(bandwidth**2 - Y, np.zeros(Y.shape))
|
256 |
+
if weights is not None:
|
257 |
+
K = K * weights
|
258 |
+
row_sums = K.sum(axis=0, keepdims=True)
|
259 |
+
P = K / (row_sums + 1e-10)
|
260 |
+
P = P.transpose()
|
261 |
+
pts_in_prim = 0.3 * (np.matmul(P, pts_in) - pts_in) + pts_in
|
262 |
+
diff = np.sqrt(np.sum((pts_in_prim - pts_in)**2))
|
263 |
+
pts_in = pts_in_prim
|
264 |
+
num_iter += 1
|
265 |
+
return pts_in
|
266 |
+
|
267 |
+
def nms_meanshift(pts_in, density, bandwidth):
|
268 |
+
"""
|
269 |
+
NMS to extract modes after meanshift. Code refers to sci-kit-learn.
|
270 |
+
:param pts_in: input points
|
271 |
+
:param density: density at each point
|
272 |
+
:param bandwidth: bandwidth used in meanshift. Used here as neighbor region for NMS
|
273 |
+
:return: extracted clusters.
|
274 |
+
"""
|
275 |
+
Y = np.sum(((pts_in[np.newaxis, ...] - pts_in[:, np.newaxis, :]) ** 2), axis=2)
|
276 |
+
sorted_ids = np.argsort(density)[::-1]
|
277 |
+
unique = np.ones(len(sorted_ids), dtype=bool)
|
278 |
+
dist = np.sqrt(Y)
|
279 |
+
for i in sorted_ids:
|
280 |
+
if unique[i]:
|
281 |
+
neighbor_idxs = np.argwhere(dist[:, i] <= bandwidth)
|
282 |
+
unique[neighbor_idxs.squeeze()] = 0
|
283 |
+
unique[i] = 1 # leave the current point as unique
|
284 |
+
pts_in = pts_in[unique]
|
285 |
+
return pts_in
|
286 |
+
|
287 |
+
def get_predictions(y_pred_np, attn_pred_np=None,bandwidth=0.05, threshold=0.001):
|
288 |
+
"""
|
289 |
+
get the final predictions
|
290 |
+
:param pts: input points
|
291 |
+
:param weights: weight per point during clustering
|
292 |
+
:return: clustered points
|
293 |
+
"""
|
294 |
+
# if attn_pred_np is None:
|
295 |
+
# attn_pred_np = np.ones(y_pred_np.shape[0])
|
296 |
+
y_pred_np = meanshift_cluster(y_pred_np, bandwidth, attn_pred_np, max_iter=40)
|
297 |
+
|
298 |
+
|
299 |
+
Y_dist = np.sum(((y_pred_np[np.newaxis, ...] - y_pred_np[:, np.newaxis, :]) ** 2), axis=2)
|
300 |
+
density = np.maximum(bandwidth ** 2 - Y_dist, np.zeros(Y_dist.shape))
|
301 |
+
density = np.sum(density, axis=0)
|
302 |
+
density_sum = np.sum(density)
|
303 |
+
y_pred_np = y_pred_np[density / density_sum > threshold]
|
304 |
+
|
305 |
+
density = density[density / density_sum > threshold]
|
306 |
+
pred_joints = nms_meanshift(y_pred_np, density, bandwidth)
|
307 |
+
return pred_joints
|
308 |
+
|
309 |
+
|
310 |
+
if __name__ == '__main__':
|
311 |
+
points_cloud = np.ones((100, 3))
|
312 |
+
predict_out = get_predictions(points_cloud, bandwidth=0.05, threshold=0.001)
|
313 |
+
print(predict_out.shape)
|
314 |
+
|
Anymate/utils/eval_utils.py
ADDED
@@ -0,0 +1,225 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from tqdm import tqdm
|
2 |
+
import torch
|
3 |
+
import torch.nn.functional as F
|
4 |
+
import numpy as np
|
5 |
+
import point_cloud_utils as pcu
|
6 |
+
from Anymate.utils.loss_utils import chamfer_distance_with_average, cross_entropy_with_probs_batch, cos_loss, cos_loss_clamp
|
7 |
+
from ThirdParty.Rignet_utils.utils import get_skel
|
8 |
+
from ThirdParty.Rignet_utils.Rignet_loss import edit_dist, chamfer_dist, joint2bone_chamfer_dist, bone2bone_chamfer_dist
|
9 |
+
from scipy.optimize import linear_sum_assignment
|
10 |
+
|
11 |
+
def evaluate_joint(joints, joints_gt, threshold=1e-1):
|
12 |
+
"""
|
13 |
+
joints: list of predicted joints: tensor of shape (n,joints_num,3)
|
14 |
+
joints_gt: list of ground truth joints : tensor of shape (n,joints_num,3)
|
15 |
+
"""
|
16 |
+
chamfer_loss_all = 0
|
17 |
+
emd_loss_all = 0
|
18 |
+
precision = 0
|
19 |
+
recall = 0
|
20 |
+
count = 0
|
21 |
+
|
22 |
+
for i in tqdm(range(len(joints))):
|
23 |
+
joint_predict = joints[i].cpu()
|
24 |
+
joint_gt = joints_gt[i].cpu()
|
25 |
+
distance_matrix = torch.cdist(joint_gt, joint_predict) # (n_gt, n_predict)
|
26 |
+
n_gt,n_predict = distance_matrix.shape
|
27 |
+
min_distance_pred = torch.min(distance_matrix, dim=0)
|
28 |
+
min_distance_gt = torch.min(distance_matrix, dim=1)
|
29 |
+
precision += torch.sum(min_distance_pred.values < threshold).item()/n_predict
|
30 |
+
recall += torch.sum(min_distance_gt.values < threshold).item()/n_gt
|
31 |
+
|
32 |
+
chamfer_loss_all += chamfer_distance_with_average(joint_predict.unsqueeze(0), joint_gt.unsqueeze(0))
|
33 |
+
joint_predict = joint_predict.numpy().astype(np.float64)
|
34 |
+
joint_gt = joint_gt.numpy().astype(np.float64)
|
35 |
+
emd,_ = pcu.earth_movers_distance(joint_predict, joint_gt)
|
36 |
+
emd_loss_all += emd
|
37 |
+
|
38 |
+
count += 1
|
39 |
+
|
40 |
+
print('------------------------------------')
|
41 |
+
print('Evaluation results for joint:')
|
42 |
+
print('chamfer_loss:', chamfer_loss_all/count)
|
43 |
+
print('emd_loss:', emd_loss_all/count)
|
44 |
+
print('precision:', precision/count)
|
45 |
+
print('recall:', recall/count)
|
46 |
+
print('count:', count)
|
47 |
+
print('------------------------------------')
|
48 |
+
return chamfer_loss_all/count, emd_loss_all/count, precision/count, recall/count
|
49 |
+
|
50 |
+
def evaluate_connectivity(conns, conns_gt, joints_gt, vox_list):
|
51 |
+
|
52 |
+
"""
|
53 |
+
conns: list of predicted connections probability: tensor of shape (n,joints_num,joints_num)
|
54 |
+
conns_gt: list of ground truth connections: tensor of shape (n,joints_num,joints_num)
|
55 |
+
"""
|
56 |
+
|
57 |
+
precision_all = 0
|
58 |
+
recall_all = 0
|
59 |
+
cross_entropy_all = 0
|
60 |
+
bone2bone_dist_con = 0
|
61 |
+
count = 0
|
62 |
+
for i in tqdm(range(len(conns))):
|
63 |
+
|
64 |
+
conn_predict = conns[i].cpu().numpy()
|
65 |
+
conn_gt = conns_gt[i].cpu().numpy()
|
66 |
+
joints = joints_gt[i].cpu().numpy()
|
67 |
+
vox = vox_list[i]
|
68 |
+
|
69 |
+
cross_entropy_all += cross_entropy_with_probs_batch(torch.from_numpy(conn_predict).unsqueeze(0), torch.from_numpy(conn_gt).unsqueeze(0), reduction='mean')
|
70 |
+
# consider to add tree edit distance (need joint and vox information)
|
71 |
+
pred_skel, parent_matrix = get_skel(joints, conn_predict, vox=vox)
|
72 |
+
gt_skel, parent_matrix = get_skel(joints, conn_gt, vox=vox)
|
73 |
+
bone2bone_dist_con += bone2bone_chamfer_dist(pred_skel, gt_skel)
|
74 |
+
|
75 |
+
conn_predict = np.argmax(conn_predict, axis=1)
|
76 |
+
conn_gt = np.argmax(conn_gt, axis=1)
|
77 |
+
connection_matrix_pre = torch.zeros((len(conn_predict),len(conn_predict)))
|
78 |
+
connection_matrix_gt = torch.zeros((len(conn_predict),len(conn_predict)))
|
79 |
+
|
80 |
+
for i in range(len(conn_predict)):
|
81 |
+
connection_matrix_pre[i][conn_predict[i]] = 1
|
82 |
+
connection_matrix_pre[conn_predict[i]][i] = 1
|
83 |
+
connection_matrix_gt[i][conn_gt[i]] = 1
|
84 |
+
connection_matrix_gt[conn_gt[i]][i] = 1
|
85 |
+
|
86 |
+
TP = 0
|
87 |
+
FP = 0
|
88 |
+
FN = 0
|
89 |
+
FP = 0
|
90 |
+
|
91 |
+
for i in range(len(conn_predict)):
|
92 |
+
if connection_matrix_gt[i][conn_predict[i]] == 1:
|
93 |
+
TP += 1
|
94 |
+
if connection_matrix_gt[i][conn_predict[i]] == 0:
|
95 |
+
FP += 1
|
96 |
+
if connection_matrix_pre[i][conn_gt[i]] == 0:
|
97 |
+
FN += 1
|
98 |
+
|
99 |
+
precision = TP/(TP+FP)
|
100 |
+
recall = TP/(TP+FN)
|
101 |
+
|
102 |
+
precision_all += precision
|
103 |
+
recall_all += recall
|
104 |
+
count+=1
|
105 |
+
print('------------------------------------')
|
106 |
+
print('Evaluation results for connectivity:')
|
107 |
+
print('precision:',precision_all/count)
|
108 |
+
print('recall:',recall_all/count)
|
109 |
+
print('cross_entropy:',cross_entropy_all/count)
|
110 |
+
print('bone2bone_dist_con:',bone2bone_dist_con/count)
|
111 |
+
print('count:',count)
|
112 |
+
print('------------------------------------')
|
113 |
+
return precision_all/count, recall_all/count
|
114 |
+
|
115 |
+
def evaluate_skinning(skins, skins_gt, threshold=5e-2):
|
116 |
+
"""
|
117 |
+
skins: list of predicted skinning weights: tensor of shape (n,vertices_num, bones_num)
|
118 |
+
skins_gt: list of ground truth skinning weights: tensor of shape (n,vertices_num, bones_num)
|
119 |
+
"""
|
120 |
+
cs_loss = 0
|
121 |
+
ce_loss = 0
|
122 |
+
cs_loss_clamp = 0
|
123 |
+
count = 0
|
124 |
+
L1_loss = 0
|
125 |
+
precision = 0
|
126 |
+
recall = 0
|
127 |
+
mean_l1_dist = 0
|
128 |
+
|
129 |
+
for i in tqdm(range(len(skins))):
|
130 |
+
skin_predict = skins[i].cpu().unsqueeze(0)
|
131 |
+
skin_gt = skins_gt[i].cpu().unsqueeze(0)
|
132 |
+
|
133 |
+
precision_one = 0
|
134 |
+
recall_one = 0
|
135 |
+
|
136 |
+
ce_loss += cross_entropy_with_probs_batch(skin_predict, skin_gt)
|
137 |
+
cs_loss += cos_loss(skin_predict, skin_gt)
|
138 |
+
cs_loss_clamp += cos_loss_clamp(skin_predict, skin_gt)
|
139 |
+
L1_loss += F.l1_loss(skin_predict, skin_gt)
|
140 |
+
skin_predict = skin_predict[0].cpu().detach().numpy()
|
141 |
+
skin_gt = skin_gt[0].cpu().detach().numpy()
|
142 |
+
mean_l1_dist += np.sum(np.abs(skin_predict - skin_gt )) / len(skin_predict)
|
143 |
+
|
144 |
+
for i in range(len(skin_predict)):
|
145 |
+
influencial_bone_predict = skin_predict[i] >=threshold
|
146 |
+
influencial_bone_gt = skin_gt[i] >=threshold
|
147 |
+
influencial_bone_correct = influencial_bone_predict*influencial_bone_gt
|
148 |
+
|
149 |
+
if np.sum(influencial_bone_predict)==0 or np.sum(influencial_bone_gt)==0:
|
150 |
+
continue
|
151 |
+
precision_one += np.sum(influencial_bone_correct)/np.sum(influencial_bone_predict)
|
152 |
+
recall_one += np.sum(influencial_bone_correct)/np.sum(influencial_bone_gt)
|
153 |
+
|
154 |
+
precision += precision_one/len(skin_predict)
|
155 |
+
recall += recall_one/len(skin_predict)
|
156 |
+
count +=1
|
157 |
+
|
158 |
+
print('------------------------------------')
|
159 |
+
print('Evaluation results for skinning:')
|
160 |
+
print('cos loss: ', cs_loss/count)
|
161 |
+
print('ce loss: ', ce_loss/count)
|
162 |
+
print('cs_loss_clamp: ', cs_loss_clamp/count)
|
163 |
+
print('L1 loss: ', L1_loss/count)
|
164 |
+
print('mean_l1_dist: ', mean_l1_dist/count)
|
165 |
+
print('precision: ', precision/count)
|
166 |
+
print('recall: ', recall/count)
|
167 |
+
print('count: ', count)
|
168 |
+
print('------------------------------------')
|
169 |
+
|
170 |
+
def evaluate_skeleton(joints,joints_gt,conns,conns_gt,vox_list,fs_threshold=0.2):
|
171 |
+
|
172 |
+
"""
|
173 |
+
joints: list of predicted joints: tensor of shape (n,joints_num,3)
|
174 |
+
joints_gt: list of ground truth joints : tensor of shape (n,joints_num,3)
|
175 |
+
conns: list of predicted connections probability: tensor of shape (n,joints_num,joints_num)
|
176 |
+
conns_gt: list of ground truth connections: tensor of shape (n,joints_num,joints_num)
|
177 |
+
vox_list: list of voxel: (n,88,88,88)
|
178 |
+
"""
|
179 |
+
|
180 |
+
data_count = 0
|
181 |
+
chamfer_score = 0
|
182 |
+
j2b_chamfer_joint = 0
|
183 |
+
bone2bone_dist_joint = 0
|
184 |
+
edit_distance_joint = 0
|
185 |
+
joint_IoU_total = 0
|
186 |
+
joint_precision_total = 0
|
187 |
+
joint_recall_total = 0
|
188 |
+
|
189 |
+
for i in tqdm(range(len(joints))):
|
190 |
+
joint_predict = joints[i].cpu().numpy()
|
191 |
+
joint_gt = joints_gt[i].cpu().numpy()
|
192 |
+
conn_predict = conns[i].cpu().numpy()
|
193 |
+
conn_gt = conns_gt[i].cpu().numpy()
|
194 |
+
vox = vox_list[i]
|
195 |
+
|
196 |
+
# add shape diameter after we have vertex and faces
|
197 |
+
# shape_diameter = get_shape_diameter(mesh, points, parent_index[:,0])
|
198 |
+
|
199 |
+
dist_matrix = np.sqrt(np.sum((joint_predict[np.newaxis, ...] - joint_gt[:, np.newaxis, :]) ** 2, axis=2))
|
200 |
+
row_ind, col_ind = linear_sum_assignment(dist_matrix)
|
201 |
+
# fs_threshold = shape_diameter[row_ind]
|
202 |
+
joint_IoU = 2 * np.sum(dist_matrix[row_ind, col_ind] < fs_threshold) / (len(joint_predict) + len(joint_gt))
|
203 |
+
joint_IoU_total += joint_IoU
|
204 |
+
joint_precision = np.sum(dist_matrix[row_ind, col_ind] < fs_threshold) / len(joint_predict)
|
205 |
+
joint_precision_total += joint_precision
|
206 |
+
joint_recall = np.sum(dist_matrix[row_ind, col_ind] < fs_threshold) / len(joint_gt)
|
207 |
+
joint_recall_total += joint_recall
|
208 |
+
|
209 |
+
pred_skel_joint,parent_matrix = get_skel(joint_predict,conn_predict,vox=vox)
|
210 |
+
gt_skel, parent_matrix = get_skel(joint_gt,conn_gt,vox=vox)
|
211 |
+
chamfer_score += chamfer_dist(joint_predict, joint_gt)
|
212 |
+
j2b_chamfer_joint += joint2bone_chamfer_dist(pred_skel_joint, gt_skel)
|
213 |
+
bone2bone_dist_joint += bone2bone_chamfer_dist(pred_skel_joint, gt_skel)
|
214 |
+
edit_distance_joint += edit_dist(pred_skel_joint, gt_skel)
|
215 |
+
data_count+=1
|
216 |
+
|
217 |
+
print('------------------------------------')
|
218 |
+
print('Evaluation results for skeleton:')
|
219 |
+
print('chamfer_score:', chamfer_score/data_count)
|
220 |
+
print('j2b_chamfer_joint:', j2b_chamfer_joint/data_count)
|
221 |
+
print('bone2bone_dist_joint:', bone2bone_dist_joint/data_count)
|
222 |
+
print('joint_IoU:', joint_IoU_total/data_count)
|
223 |
+
print('joint_precision:', joint_precision_total/data_count)
|
224 |
+
print('joint_recall:', joint_recall_total/data_count)
|
225 |
+
print('------------------------------------')
|
Anymate/utils/loss_utils.py
ADDED
@@ -0,0 +1,56 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
import torch.nn.functional as F
|
3 |
+
import torch.nn as nn
|
4 |
+
def chamfer_distance_with_average(p1, p2):
|
5 |
+
|
6 |
+
'''
|
7 |
+
Calculate Chamfer Distance between two point sets
|
8 |
+
:param p1: size[1, N, D]
|
9 |
+
:param p2: size[1, M, D]
|
10 |
+
:param debug: whether need to output debug info
|
11 |
+
:return: sum of Chamfer Distance of two point sets
|
12 |
+
'''
|
13 |
+
|
14 |
+
assert p1.size(0) == 1 and p2.size(0) == 1
|
15 |
+
assert p1.size(2) == p2.size(2)
|
16 |
+
p1 = p1.repeat(p2.size(1), 1, 1)
|
17 |
+
p1 = p1.transpose(0, 1)
|
18 |
+
p2 = p2.repeat(p1.size(0), 1, 1)
|
19 |
+
dist = torch.add(p1, torch.neg(p2))
|
20 |
+
dist_norm = torch.norm(dist, 2, dim=2)
|
21 |
+
dist1 = torch.min(dist_norm, dim=1)[0]
|
22 |
+
dist2 = torch.min(dist_norm, dim=0)[0]
|
23 |
+
loss = 0.5 * ((torch.mean(dist1)) + (torch.mean(dist2)))
|
24 |
+
return loss
|
25 |
+
|
26 |
+
def cross_entropy_with_probs_batch(input, target, weight=None, reduction="mean"): # tested, same as nn.CrossEntropyLoss at dim=1, CE can be negative
|
27 |
+
# input_logsoftmax = F.log_softmax(input, dim=2)
|
28 |
+
input_logsoftmax = torch.log(input+1e-6)
|
29 |
+
cum_losses = -target * input_logsoftmax
|
30 |
+
if weight is not None:
|
31 |
+
cum_losses = cum_losses * weight.unsqueeze(1) # Broadcasting the weight
|
32 |
+
|
33 |
+
if reduction == "none":
|
34 |
+
return cum_losses
|
35 |
+
elif reduction == "mean":
|
36 |
+
return cum_losses.sum(dim=2).mean(dim=1).mean(dim=0)
|
37 |
+
elif reduction == "sum":
|
38 |
+
return cum_losses.sum(dim=2).sum(dim=1).mean(dim=0)
|
39 |
+
else:
|
40 |
+
raise ValueError("Keyword 'reduction' must be one of ['none', 'mean', 'sum']")
|
41 |
+
|
42 |
+
def cos_loss(input, target):
|
43 |
+
# input = F.softmax(input, dim=-1)
|
44 |
+
cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
|
45 |
+
similarity = cos(input, target)
|
46 |
+
loss = 1 - similarity.mean()
|
47 |
+
return loss
|
48 |
+
|
49 |
+
def cos_loss_clamp(input, target):
|
50 |
+
# input = F.softmax(input, dim=-1)*(1 + 2*0.001) - 0.001
|
51 |
+
input = input*(1 + 2*0.001) - 0.001
|
52 |
+
input = torch.clamp(input, 0, 1)
|
53 |
+
cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
|
54 |
+
similarity = cos(input, target)
|
55 |
+
loss = 1 - similarity.mean()
|
56 |
+
return loss
|
Anymate/utils/render_utils.py
ADDED
@@ -0,0 +1,1169 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import bpy
|
2 |
+
import numpy as np
|
3 |
+
from mathutils import Vector, Matrix
|
4 |
+
from tqdm import tqdm
|
5 |
+
import glob
|
6 |
+
import os
|
7 |
+
import torch
|
8 |
+
from PIL import Image
|
9 |
+
import matplotlib.pyplot as plt
|
10 |
+
cmap = plt.get_cmap('viridis')
|
11 |
+
import torch
|
12 |
+
import torchvision.io as io
|
13 |
+
import cv2
|
14 |
+
import trimesh
|
15 |
+
|
16 |
+
def get_data(ids, root, animate=False, shift_rig=True, id2=None, rignet=False):
|
17 |
+
dataset= torch.load('/data2/aod/testJointDataSet_9.pt')
|
18 |
+
joints = []
|
19 |
+
conns = []
|
20 |
+
skins = []
|
21 |
+
|
22 |
+
for id in ids:
|
23 |
+
if id2 is None:
|
24 |
+
for data in dataset:
|
25 |
+
if id in data['name']:
|
26 |
+
print(data['name'])
|
27 |
+
break
|
28 |
+
else:
|
29 |
+
for data in dataset:
|
30 |
+
if id2 in data['name']:
|
31 |
+
print(data['name'])
|
32 |
+
break
|
33 |
+
|
34 |
+
joint = torch.tensor(torch.load(root + '/joints/' + id + '.pt')).cpu()
|
35 |
+
if shift_rig and id2 is None:
|
36 |
+
y_max = data['points_cloud'][:,1].max()
|
37 |
+
joint = joint/2 + torch.tensor([0,y_max/2,0])
|
38 |
+
temp = joint[:, 1].clone()
|
39 |
+
joint[:, 1] = -joint[:, 2]
|
40 |
+
joint[:, 2] = temp
|
41 |
+
|
42 |
+
conn = torch.tensor(torch.load(root + '/connectivity/' + id + '.pt')).long()
|
43 |
+
if not animate:
|
44 |
+
skin = torch.load(root + '/skinning/' + id + '.pt')
|
45 |
+
if rignet:
|
46 |
+
skins.append(skin[0])
|
47 |
+
elif id2 is None:
|
48 |
+
skins.append(skin[0].softmax(dim=-1).cpu().numpy())
|
49 |
+
else:
|
50 |
+
skins.append(skin)
|
51 |
+
|
52 |
+
joints.append(joint)
|
53 |
+
conns.append(conn)
|
54 |
+
|
55 |
+
return joints, conns, skins
|
56 |
+
|
57 |
+
def index_to_sparse(index, weight, shape):
|
58 |
+
sparse_matrix = np.zeros([shape[0], shape[1], shape[2]+1])
|
59 |
+
|
60 |
+
row_indices, col_indices = np.meshgrid(np.arange(sparse_matrix.shape[0]), np.arange(sparse_matrix.shape[1]), indexing='ij')
|
61 |
+
|
62 |
+
row_indices = np.expand_dims(row_indices, axis=-1)
|
63 |
+
col_indices = np.expand_dims(col_indices, axis=-1)
|
64 |
+
|
65 |
+
sparse_matrix[row_indices, col_indices, index] = weight
|
66 |
+
|
67 |
+
|
68 |
+
return torch.from_numpy(sparse_matrix[:, :, :-1])
|
69 |
+
|
70 |
+
def get_gt(ids, root):
|
71 |
+
dataset= torch.load('/data2/aod/testJointDataSet_9.pt')
|
72 |
+
joints = []
|
73 |
+
conns = []
|
74 |
+
skins = []
|
75 |
+
|
76 |
+
for id in ids:
|
77 |
+
for data in dataset:
|
78 |
+
if id in data['name']:
|
79 |
+
print(data['name'])
|
80 |
+
break
|
81 |
+
|
82 |
+
joint = data['joints_matrix'][:data['joints_num'], :3]
|
83 |
+
y_max = data['points_cloud'][:,1].max()
|
84 |
+
joint = joint/2 + torch.tensor([0,y_max/2,0])
|
85 |
+
temp = joint[:, 1].clone()
|
86 |
+
joint[:, 1] = -joint[:, 2]
|
87 |
+
joint[:, 2] = temp
|
88 |
+
|
89 |
+
conn = data['parent_index'][:data['joints_num']].long().unsqueeze(1)
|
90 |
+
|
91 |
+
skin = index_to_sparse(data['skin_index'].unsqueeze(0), data['skin_weight'].unsqueeze(0), [1, 8192, data['joints_num']])
|
92 |
+
|
93 |
+
joints.append(joint)
|
94 |
+
conns.append(conn)
|
95 |
+
skins.append(skin[0])
|
96 |
+
|
97 |
+
return joints, conns, skins
|
98 |
+
|
99 |
+
def empty():
|
100 |
+
bpy.ops.wm.read_homefile(use_empty=True)
|
101 |
+
# Delete all mesh objects from the scene
|
102 |
+
# for obj in bpy.context.scene.objects:
|
103 |
+
# bpy.data.objects.remove(obj, do_unlink=True)
|
104 |
+
|
105 |
+
def add_mesh(filepath, co=None, tex=False, color=(0.5, 0.5, 0.5, 1)):
|
106 |
+
bpy.ops.wm.obj_import(filepath=filepath)
|
107 |
+
obj = bpy.context.object
|
108 |
+
|
109 |
+
if not tex:
|
110 |
+
# give the mesh a material
|
111 |
+
bpy.context.view_layer.objects.active = obj
|
112 |
+
bpy.ops.object.shade_smooth()
|
113 |
+
bpy.ops.object.mode_set(mode='EDIT')
|
114 |
+
bpy.ops.mesh.select_all(action='SELECT')
|
115 |
+
bpy.ops.mesh.normals_make_consistent(inside=False)
|
116 |
+
bpy.ops.object.mode_set(mode='OBJECT')
|
117 |
+
mat = bpy.data.materials.new(name='mat')
|
118 |
+
obj.data.materials.clear()
|
119 |
+
obj.data.materials.append(mat)
|
120 |
+
mat.use_nodes = True
|
121 |
+
mat.node_tree.nodes.clear()
|
122 |
+
bsdf = mat.node_tree.nodes.new('ShaderNodeBsdfPrincipled')
|
123 |
+
output = mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
|
124 |
+
mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
|
125 |
+
mat.node_tree.nodes['Principled BSDF'].inputs['Roughness'].default_value = 0.8
|
126 |
+
# mat.node_tree.nodes['Principled BSDF'].inputs['Specular'].default_value = 0.5
|
127 |
+
# mat.node_tree.nodes['Principled BSDF'].inputs['Metallic'].default_value = 0.5
|
128 |
+
mat.node_tree.nodes['Principled BSDF'].inputs['Base Color'].default_value = color
|
129 |
+
if co is not None:
|
130 |
+
obj.parent = co
|
131 |
+
|
132 |
+
def create_sphere(location, size=0.01, color=(1.0, 0.0, 0.0, 1.0), reduced=False):
|
133 |
+
if reduced:
|
134 |
+
bpy.ops.mesh.primitive_uv_sphere_add(radius=size, location=location, segments=8, ring_count=4)
|
135 |
+
else:
|
136 |
+
bpy.ops.mesh.primitive_uv_sphere_add(radius=size, location=location)
|
137 |
+
sphere = bpy.context.active_object
|
138 |
+
|
139 |
+
material_name = f"ColorMaterial_{color}"
|
140 |
+
material = bpy.data.materials.get(material_name)
|
141 |
+
|
142 |
+
if not material:
|
143 |
+
material = bpy.data.materials.new(name=material_name)
|
144 |
+
material.use_nodes = True
|
145 |
+
material.node_tree.nodes.clear()
|
146 |
+
bsdf = material.node_tree.nodes.new('ShaderNodeBsdfPrincipled')
|
147 |
+
output = material.node_tree.nodes.new('ShaderNodeOutputMaterial')
|
148 |
+
material.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
|
149 |
+
material.node_tree.nodes['Principled BSDF'].inputs['Base Color'].default_value = color
|
150 |
+
|
151 |
+
sphere.data.materials.append(material)
|
152 |
+
|
153 |
+
return sphere
|
154 |
+
|
155 |
+
def add_co(location=(0,0,0), rotation=(0,0,0), scale=(1,1,1)):
|
156 |
+
co = bpy.data.objects.new("CoordinateSystem", None)
|
157 |
+
bpy.context.collection.objects.link(co)
|
158 |
+
bpy.context.view_layer.objects.active = co
|
159 |
+
co.empty_display_size = 0.1
|
160 |
+
co.empty_display_type = 'ARROWS'
|
161 |
+
co.location = location
|
162 |
+
co.rotation_euler = rotation
|
163 |
+
co.scale = scale
|
164 |
+
|
165 |
+
return co
|
166 |
+
|
167 |
+
def add_joint(joints_matrix, co=None):
|
168 |
+
|
169 |
+
for i, joint in enumerate(joints_matrix):
|
170 |
+
sphere = create_sphere((joint[0], joint[1], joint[2]), size=0.01)
|
171 |
+
if co is not None:
|
172 |
+
sphere.parent = co
|
173 |
+
|
174 |
+
def create_blue_cone(base_point, apex_point, radius=0.1):
|
175 |
+
# Calculate the radius and length of the cone
|
176 |
+
direction = apex_point - base_point
|
177 |
+
length = direction.length
|
178 |
+
|
179 |
+
# Create cone mesh
|
180 |
+
bpy.ops.mesh.primitive_cone_add(vertices=32, radius1=radius, depth=length, location=(base_point + direction * 0.5))
|
181 |
+
cone = bpy.context.active_object
|
182 |
+
|
183 |
+
# Create or get the blue material
|
184 |
+
blue_material = bpy.data.materials.get("BlueMaterial")
|
185 |
+
if not blue_material:
|
186 |
+
blue_material = bpy.data.materials.new(name="BlueMaterial")
|
187 |
+
blue_material.use_nodes = True
|
188 |
+
blue_material.node_tree.nodes.clear()
|
189 |
+
bsdf = blue_material.node_tree.nodes.new('ShaderNodeBsdfPrincipled')
|
190 |
+
output = blue_material.node_tree.nodes.new('ShaderNodeOutputMaterial')
|
191 |
+
blue_material.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
|
192 |
+
blue_material.node_tree.nodes['Principled BSDF'].inputs['Base Color'].default_value = (0.0, 0.0, 1.0, 1.0)
|
193 |
+
|
194 |
+
cone.data.materials.append(blue_material)
|
195 |
+
|
196 |
+
# Set the cone's orientation
|
197 |
+
cone.rotation_euler = direction.to_track_quat('Z', 'Y').to_euler()
|
198 |
+
|
199 |
+
return cone
|
200 |
+
|
201 |
+
def add_conn(con_index, joints_matrix, co=None):
|
202 |
+
for i, parent in enumerate(con_index):
|
203 |
+
parent = parent.item()
|
204 |
+
if parent != i:
|
205 |
+
parent_co = Vector((joints_matrix[parent][0], joints_matrix[parent][1], joints_matrix[parent][2]))
|
206 |
+
position = Vector((joints_matrix[i][0], joints_matrix[i][1], joints_matrix[i][2]))
|
207 |
+
cone = create_blue_cone(parent_co, position, radius=0.008)
|
208 |
+
if co is not None:
|
209 |
+
cone.parent = co
|
210 |
+
|
211 |
+
def merge_images(img1, img2, output_path, alpha=1):
|
212 |
+
image_mesh = Image.open(img1)
|
213 |
+
image_rig = Image.open(img2)
|
214 |
+
|
215 |
+
if alpha == 1:
|
216 |
+
image_mesh.paste(image_rig, (0, 0), image_rig)
|
217 |
+
image_mesh.save(output_path)
|
218 |
+
return
|
219 |
+
|
220 |
+
data = image_rig.getdata()
|
221 |
+
data2 = image_mesh.getdata()
|
222 |
+
new_data = []
|
223 |
+
for item, item2 in zip(data, data2):
|
224 |
+
if item[3] == 0:
|
225 |
+
new_data.append(item2)
|
226 |
+
else:
|
227 |
+
new_data.append((int(item[0]*alpha + item2[0]*(1-alpha)), int(item[1]*alpha + item2[1]*(1-alpha)), int(item[2]*alpha + item2[2]*(1-alpha)), 255))
|
228 |
+
image_mesh.putdata(new_data)
|
229 |
+
|
230 |
+
# image_mesh.paste(image_rig, (0, 0), image_rig)
|
231 |
+
|
232 |
+
image_mesh.save(output_path)
|
233 |
+
|
234 |
+
def merge_videos(video1, video2, output_path):
|
235 |
+
|
236 |
+
# overlap two videos together, video1 is the background, video2 is the foreground
|
237 |
+
# os.system(f'ffmpeg -i {video1} -i {video2} -filter_complex "[0:v][1:v] overlay=0:0:enable=\'between(t,0,60)\'" -pix_fmt yuv420p -c:a copy {output_path}')
|
238 |
+
|
239 |
+
frames_path_1 = glob.glob(video1 + '*.png')
|
240 |
+
total_frames = len(frames_path_1)
|
241 |
+
combined_frames = []
|
242 |
+
for i in range(total_frames):
|
243 |
+
frame1 = Image.open(f'{video1}{i:04d}.png')
|
244 |
+
frame2 = Image.open(f'{video2}{i:04d}.png')
|
245 |
+
frame1.paste(frame2, (0, 0), frame2)
|
246 |
+
combined_frames.append(frame1)
|
247 |
+
|
248 |
+
# paste the combined frames on a pure white background
|
249 |
+
combined_frames_white = []
|
250 |
+
for frame in combined_frames:
|
251 |
+
white = Image.new('RGB', frame.size, (255, 255, 255))
|
252 |
+
white.paste(frame, (0, 0), frame)
|
253 |
+
combined_frames_white.append(white)
|
254 |
+
|
255 |
+
combined_frames=combined_frames_white
|
256 |
+
|
257 |
+
combined_videos = torch.stack([torch.tensor(np.array(frame)) for frame in combined_frames])[..., :3]
|
258 |
+
|
259 |
+
# write the video with high quality
|
260 |
+
# io.write_video(output_path, combined_videos, 24)
|
261 |
+
io.write_video(output_path, combined_videos, 24, video_codec='libx264', options={'crf': '18'})
|
262 |
+
|
263 |
+
# comvert the frames to mp4 video
|
264 |
+
|
265 |
+
# video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'H264'), 30, (frame1.size[0], frame1.size[1]))
|
266 |
+
# for frame in combined_frames:
|
267 |
+
# video.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
|
268 |
+
# video.release()
|
269 |
+
|
270 |
+
# video_1, audio_1, fps_1 = io.read_video(video1, pts_unit="sec")
|
271 |
+
# video_2, audio_2, fps_2 = io.read_video(video2, pts_unit="sec")
|
272 |
+
# non_zero = video_2.sum(dim=-1) != 0
|
273 |
+
# non_zero = torch.stack([non_zero, non_zero, non_zero], dim=-1)
|
274 |
+
# video_1[non_zero] = video_2[non_zero]
|
275 |
+
# io.write_video(output_path, video_1, int(fps_1['video_fps']))
|
276 |
+
|
277 |
+
def add_skin(filepath, skin, bone_index, co=None, pc=None):
|
278 |
+
bpy.ops.wm.obj_import(filepath=filepath)
|
279 |
+
obj = bpy.context.object
|
280 |
+
|
281 |
+
bpy.context.view_layer.objects.active = obj
|
282 |
+
bpy.ops.object.shade_smooth()
|
283 |
+
bpy.ops.object.mode_set(mode='EDIT')
|
284 |
+
bpy.ops.mesh.select_all(action='SELECT')
|
285 |
+
bpy.ops.mesh.normals_make_consistent(inside=False)
|
286 |
+
bpy.ops.object.mode_set(mode='OBJECT')
|
287 |
+
|
288 |
+
if co is not None:
|
289 |
+
obj.parent = co
|
290 |
+
|
291 |
+
if pc is not None:
|
292 |
+
skin = np.array(skin)
|
293 |
+
pc = pc[:, :3].numpy()
|
294 |
+
y_max = pc[:, 1].max()
|
295 |
+
pc = pc + np.array([0, y_max, 0])
|
296 |
+
pc = pc / 2
|
297 |
+
new_skin = np.zeros((len(obj.data.vertices), skin.shape[1]))
|
298 |
+
for i, v in enumerate(obj.data.vertices):
|
299 |
+
v_co = np.array(v.co)
|
300 |
+
|
301 |
+
dist = np.linalg.norm(pc - v_co, axis=1)
|
302 |
+
# min_idx = np.argmin(dist)
|
303 |
+
# sort, and then get top 3 index
|
304 |
+
min_idx_list = np.argsort(dist)[:3]
|
305 |
+
|
306 |
+
for min_idx in min_idx_list:
|
307 |
+
# get inverse distance weight
|
308 |
+
interpolate_weight = np.square(1 / dist[min_idx]) / np.square(1 / dist[min_idx_list]).sum()
|
309 |
+
new_skin[i] = new_skin[i] + interpolate_weight * skin[min_idx]
|
310 |
+
|
311 |
+
skin = new_skin
|
312 |
+
|
313 |
+
color_list = skin
|
314 |
+
|
315 |
+
color_list = color_list[:,bone_index]
|
316 |
+
|
317 |
+
vertex_colors = obj.data.vertex_colors.new()
|
318 |
+
|
319 |
+
for poly in obj.data.polygons:
|
320 |
+
for loop_index in poly.loop_indices:
|
321 |
+
|
322 |
+
vertex_index = obj.data.loops[loop_index].vertex_index
|
323 |
+
# Get the weight for the vertex
|
324 |
+
weight = color_list[vertex_index]
|
325 |
+
|
326 |
+
color = cmap(weight)
|
327 |
+
|
328 |
+
# Assign the weight to the vertex color (RGBA)
|
329 |
+
vertex_colors.data[loop_index].color = color # Use the weight for RGB
|
330 |
+
|
331 |
+
# let bsdf use vertex color and then output to surface
|
332 |
+
mat = bpy.data.materials.new(name='mat')
|
333 |
+
# delete all material of obj
|
334 |
+
obj.data.materials.clear()
|
335 |
+
obj.data.materials.append(mat)
|
336 |
+
mat.use_nodes = True
|
337 |
+
mat.node_tree.nodes.clear()
|
338 |
+
vertex_color = mat.node_tree.nodes.new('ShaderNodeVertexColor')
|
339 |
+
bsdf = mat.node_tree.nodes.new('ShaderNodeBsdfPrincipled')
|
340 |
+
output = mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
|
341 |
+
mat.node_tree.links.new(vertex_color.outputs['Color'], bsdf.inputs['Base Color'])
|
342 |
+
mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
|
343 |
+
mat.node_tree.nodes['Principled BSDF'].inputs['Roughness'].default_value = 0.5
|
344 |
+
|
345 |
+
|
346 |
+
|
347 |
+
def add_pc(points):
|
348 |
+
base_sphere = create_sphere((points[0][0], points[0][1], points[0][2]), size=0.003, color=cmap(0), reduced=True)
|
349 |
+
# copy the base sphere to create the rest of the spheres
|
350 |
+
for i in tqdm(range(1, points.shape[0])):
|
351 |
+
new_sphere = base_sphere.copy()
|
352 |
+
new_sphere.location = (points[i][0], points[i][1], points[i][2])
|
353 |
+
bpy.context.collection.objects.link(new_sphere)
|
354 |
+
|
355 |
+
def add_floor(back=False):
|
356 |
+
# create a plane as floor
|
357 |
+
bpy.ops.mesh.primitive_plane_add(size=50, enter_editmode=False, align='WORLD', location=(0, 20, 0))
|
358 |
+
floor = bpy.context.object
|
359 |
+
floor.name = 'floor'
|
360 |
+
# set white material for floor
|
361 |
+
mat = bpy.data.materials.new(name='floor_mat')
|
362 |
+
floor.data.materials.append(mat)
|
363 |
+
mat.use_nodes = True
|
364 |
+
mat.node_tree.nodes.clear()
|
365 |
+
bsdf = mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse')
|
366 |
+
output = mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
|
367 |
+
mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
|
368 |
+
mat.node_tree.nodes['Diffuse BSDF'].inputs['Color'].default_value = (1, 1, 1, 1)
|
369 |
+
|
370 |
+
if back:
|
371 |
+
# create a plane as background
|
372 |
+
bpy.ops.mesh.primitive_plane_add(size=30, enter_editmode=False, align='WORLD', location=(0, 15, 0), rotation=(-0.5*np.pi, 0, 0))
|
373 |
+
background = bpy.context.object
|
374 |
+
background.name = 'background'
|
375 |
+
# set white material for background
|
376 |
+
mat = bpy.data.materials.new(name='background_mat')
|
377 |
+
background.data.materials.append(mat)
|
378 |
+
mat.use_nodes = True
|
379 |
+
mat.node_tree.nodes.clear()
|
380 |
+
bsdf = mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse')
|
381 |
+
output = mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
|
382 |
+
mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
|
383 |
+
mat.node_tree.nodes['Diffuse BSDF'].inputs['Color'].default_value = (1, 1, 1, 1)
|
384 |
+
|
385 |
+
def setup_render():
|
386 |
+
# color management
|
387 |
+
bpy.context.scene.view_settings.view_transform = 'Standard'
|
388 |
+
|
389 |
+
# set the render engine to Cycles
|
390 |
+
bpy.context.scene.render.engine = 'CYCLES'
|
391 |
+
# enable cuda
|
392 |
+
bpy.context.preferences.addons['cycles'].preferences.get_devices()
|
393 |
+
bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
|
394 |
+
bpy.context.scene.cycles.device = 'GPU'
|
395 |
+
|
396 |
+
# set render background to transparent
|
397 |
+
bpy.context.scene.render.film_transparent = True
|
398 |
+
|
399 |
+
def render(output_path, shadow=True, shading=True, quick=False):
|
400 |
+
|
401 |
+
if shadow:
|
402 |
+
add_floor()
|
403 |
+
|
404 |
+
if shading:
|
405 |
+
# create a sun light
|
406 |
+
bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
|
407 |
+
light = bpy.context.object
|
408 |
+
light.data.energy = 5
|
409 |
+
# angle pointing to the origin
|
410 |
+
light.rotation_euler = (0.1*np.pi, 0, 0)
|
411 |
+
# set angle
|
412 |
+
light.data.angle = 0.08*np.pi
|
413 |
+
|
414 |
+
else:
|
415 |
+
# global illumination by create world light
|
416 |
+
world = bpy.data.worlds.new('World')
|
417 |
+
bpy.context.scene.world = world
|
418 |
+
world.use_nodes = True
|
419 |
+
world_light = world.node_tree.nodes['Background']
|
420 |
+
world_light.inputs['Strength'].default_value = 1
|
421 |
+
world_light.inputs['Color'].default_value = (1, 1, 1, 1)
|
422 |
+
|
423 |
+
# create a camera
|
424 |
+
cam = bpy.data.cameras.new("Camera")
|
425 |
+
cam_ob = bpy.data.objects.new("Camera", cam)
|
426 |
+
camera = bpy.data.objects['Camera']
|
427 |
+
bpy.context.scene.collection.objects.link(camera)
|
428 |
+
camera.location = Vector((2, -1.5, 2))
|
429 |
+
look_at = Vector((0, 0, 0.36))
|
430 |
+
# compute the rotation
|
431 |
+
camera.rotation_mode = 'QUATERNION'
|
432 |
+
camera.rotation_quaternion = (camera.location - look_at).to_track_quat('Z', 'Y')
|
433 |
+
# set size
|
434 |
+
camera.data.sensor_width = 26
|
435 |
+
# set the camera to be active
|
436 |
+
bpy.context.scene.camera = camera
|
437 |
+
|
438 |
+
|
439 |
+
|
440 |
+
# make the rendered image square
|
441 |
+
bpy.context.scene.render.resolution_x = 2048
|
442 |
+
bpy.context.scene.render.resolution_y = 2048
|
443 |
+
|
444 |
+
setup_render()
|
445 |
+
|
446 |
+
if quick:
|
447 |
+
# reduce the number of samples
|
448 |
+
bpy.context.scene.cycles.samples = 128
|
449 |
+
bpy.context.scene.cycles.preview_samples = 128
|
450 |
+
bpy.context.scene.cycles.max_bounces = 1
|
451 |
+
bpy.context.scene.cycles.min_bounces = 1
|
452 |
+
bpy.context.scene.cycles.diffuse_bounces = 1
|
453 |
+
bpy.context.scene.cycles.glossy_bounces = 1
|
454 |
+
else:
|
455 |
+
bpy.context.scene.cycles.samples = 1024
|
456 |
+
bpy.context.scene.cycles.preview_samples = 1024
|
457 |
+
bpy.context.scene.cycles.max_bounces = 4
|
458 |
+
bpy.context.scene.cycles.min_bounces = 4
|
459 |
+
bpy.context.scene.cycles.diffuse_bounces = 4
|
460 |
+
bpy.context.scene.cycles.glossy_bounces = 4
|
461 |
+
|
462 |
+
# output path
|
463 |
+
# output_path = '/home/ydengbd/objaverse/test.png'
|
464 |
+
bpy.context.scene.render.filepath = output_path
|
465 |
+
bpy.ops.render.render(write_still=True)
|
466 |
+
|
467 |
+
def render_spin(output_path, co, shadow=True, shading=True, quick=False):
|
468 |
+
# create a new coordinate system at the origin
|
469 |
+
new_co = add_co(location=(0, 0, 0), rotation=(0, 0, 0), scale=(1, 1, 1))
|
470 |
+
# set the object to be the child of the new coordinate system
|
471 |
+
co.parent = new_co
|
472 |
+
|
473 |
+
# add spin animation to the new coordinate system
|
474 |
+
new_co.rotation_mode = 'XYZ'
|
475 |
+
new_co.rotation_euler = (0, 0, 0)
|
476 |
+
new_co.keyframe_insert(data_path='rotation_euler', index=2, frame=0)
|
477 |
+
new_co.rotation_euler = (0, 0, 2*np.pi)
|
478 |
+
new_co.keyframe_insert(data_path='rotation_euler', index=2, frame=60)
|
479 |
+
|
480 |
+
if shadow:
|
481 |
+
add_floor()
|
482 |
+
|
483 |
+
if shading:
|
484 |
+
# create a sun light
|
485 |
+
bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
|
486 |
+
light = bpy.context.object
|
487 |
+
light.data.energy = 5
|
488 |
+
# angle pointing to the origin
|
489 |
+
light.rotation_euler = (0.1*np.pi, 0, 0)
|
490 |
+
# set angle
|
491 |
+
light.data.angle = 0.08*np.pi
|
492 |
+
|
493 |
+
else:
|
494 |
+
# global illumination by create world light
|
495 |
+
world = bpy.data.worlds.new('World')
|
496 |
+
bpy.context.scene.world = world
|
497 |
+
world.use_nodes = True
|
498 |
+
world_light = world.node_tree.nodes['Background']
|
499 |
+
world_light.inputs['Strength'].default_value = 1
|
500 |
+
world_light.inputs['Color'].default_value = (1, 1, 1, 1)
|
501 |
+
|
502 |
+
# create a camera
|
503 |
+
cam = bpy.data.cameras.new("Camera")
|
504 |
+
cam_ob = bpy.data.objects.new("Camera", cam)
|
505 |
+
camera = bpy.data.objects['Camera']
|
506 |
+
bpy.context.scene.collection.objects.link(camera)
|
507 |
+
camera.location = Vector((2, -1.5, 2))
|
508 |
+
look_at = Vector((0, 0, 0.36))
|
509 |
+
# compute the rotation
|
510 |
+
camera.rotation_mode = 'QUATERNION'
|
511 |
+
camera.rotation_quaternion = (camera.location - look_at).to_track_quat('Z', 'Y')
|
512 |
+
# set size
|
513 |
+
camera.data.sensor_width = 26
|
514 |
+
# set the camera to be active
|
515 |
+
bpy.context.scene.camera = camera
|
516 |
+
|
517 |
+
|
518 |
+
# render the animation
|
519 |
+
bpy.context.scene.frame_start = 0
|
520 |
+
bpy.context.scene.frame_end = 60
|
521 |
+
|
522 |
+
# make the rendered image square
|
523 |
+
bpy.context.scene.render.resolution_x = 1024
|
524 |
+
bpy.context.scene.render.resolution_y = 1024
|
525 |
+
|
526 |
+
setup_render()
|
527 |
+
|
528 |
+
if quick:
|
529 |
+
# reduce the number of samples
|
530 |
+
bpy.context.scene.cycles.samples = 128
|
531 |
+
bpy.context.scene.cycles.preview_samples = 128
|
532 |
+
bpy.context.scene.cycles.max_bounces = 1
|
533 |
+
bpy.context.scene.cycles.min_bounces = 1
|
534 |
+
bpy.context.scene.cycles.diffuse_bounces = 1
|
535 |
+
bpy.context.scene.cycles.glossy_bounces = 1
|
536 |
+
else:
|
537 |
+
bpy.context.scene.cycles.samples = 512
|
538 |
+
bpy.context.scene.cycles.preview_samples = 512
|
539 |
+
bpy.context.scene.cycles.max_bounces = 4
|
540 |
+
bpy.context.scene.cycles.min_bounces = 4
|
541 |
+
bpy.context.scene.cycles.diffuse_bounces = 4
|
542 |
+
bpy.context.scene.cycles.glossy_bounces = 4
|
543 |
+
|
544 |
+
# output path
|
545 |
+
bpy.context.scene.render.filepath = output_path
|
546 |
+
if output_path.endswith('.mp4'):
|
547 |
+
# render a mp4 video
|
548 |
+
bpy.context.scene.render.image_settings.file_format = 'FFMPEG'
|
549 |
+
bpy.context.scene.render.ffmpeg.format = 'MPEG4'
|
550 |
+
bpy.context.scene.render.ffmpeg.codec = 'H264'
|
551 |
+
|
552 |
+
bpy.ops.render.render(animation=True, write_still=True)
|
553 |
+
|
554 |
+
def setup_anim(armature, arti):
|
555 |
+
# enter pose mode
|
556 |
+
print('Arti shape', arti.shape)
|
557 |
+
bpy.ops.object.mode_set(mode='POSE')
|
558 |
+
print('total bones', len(armature.pose.bones))
|
559 |
+
for i, pose_bone in enumerate(armature.pose.bones):
|
560 |
+
pose_bone.rotation_mode = 'XYZ'
|
561 |
+
pose_bone.keyframe_insert(data_path="rotation_euler", frame=0)
|
562 |
+
|
563 |
+
pose_bone.rotation_euler = arti[i]
|
564 |
+
pose_bone.keyframe_insert(data_path="rotation_euler", frame=30)
|
565 |
+
|
566 |
+
pose_bone.rotation_euler = Vector((0, 0, 0))
|
567 |
+
pose_bone.keyframe_insert(data_path="rotation_euler", frame=60)
|
568 |
+
bpy.ops.object.mode_set(mode='OBJECT')
|
569 |
+
|
570 |
+
def render_anim(output_path, armature, arti, quick=False):
|
571 |
+
# enter pose mode
|
572 |
+
setup_anim(armature, arti)
|
573 |
+
|
574 |
+
# save blend file
|
575 |
+
# bpy.ops.wm.save_as_mainfile(filepath='/data2/ydengbd/objaverse/test.blend')
|
576 |
+
|
577 |
+
add_floor()
|
578 |
+
|
579 |
+
# create a sun light
|
580 |
+
bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
|
581 |
+
light = bpy.context.object
|
582 |
+
light.data.energy = 5
|
583 |
+
# angle pointing to the origin
|
584 |
+
light.rotation_euler = (50/180*np.pi, 0, -20/180*np.pi)
|
585 |
+
# set angle
|
586 |
+
light.data.angle = 12/180*np.pi
|
587 |
+
|
588 |
+
# create a camera
|
589 |
+
cam = bpy.data.cameras.new("Camera")
|
590 |
+
cam_ob = bpy.data.objects.new("Camera", cam)
|
591 |
+
camera = bpy.data.objects['Camera']
|
592 |
+
bpy.context.scene.collection.objects.link(camera)
|
593 |
+
camera.location = Vector((0, -3, 1.3))
|
594 |
+
camera.rotation_euler = Vector((1.309, 0, 0))
|
595 |
+
# set size
|
596 |
+
camera.data.sensor_width = 36
|
597 |
+
# set the camera to be active
|
598 |
+
bpy.context.scene.camera = camera
|
599 |
+
|
600 |
+
# render the animation
|
601 |
+
bpy.context.scene.frame_start = 0
|
602 |
+
bpy.context.scene.frame_end = 60
|
603 |
+
|
604 |
+
# make the rendered image square
|
605 |
+
bpy.context.scene.render.resolution_x = 1920
|
606 |
+
bpy.context.scene.render.resolution_y = 1080
|
607 |
+
|
608 |
+
setup_render()
|
609 |
+
|
610 |
+
if quick:
|
611 |
+
# reduce the number of samples
|
612 |
+
bpy.context.scene.cycles.samples = 128
|
613 |
+
bpy.context.scene.cycles.preview_samples = 128
|
614 |
+
bpy.context.scene.cycles.max_bounces = 1
|
615 |
+
bpy.context.scene.cycles.min_bounces = 1
|
616 |
+
bpy.context.scene.cycles.diffuse_bounces = 1
|
617 |
+
bpy.context.scene.cycles.glossy_bounces = 1
|
618 |
+
else:
|
619 |
+
bpy.context.scene.cycles.samples = 1024
|
620 |
+
bpy.context.scene.cycles.preview_samples = 1024
|
621 |
+
bpy.context.scene.cycles.max_bounces = 4
|
622 |
+
bpy.context.scene.cycles.min_bounces = 4
|
623 |
+
bpy.context.scene.cycles.diffuse_bounces = 4
|
624 |
+
bpy.context.scene.cycles.glossy_bounces = 4
|
625 |
+
|
626 |
+
# output path
|
627 |
+
bpy.context.scene.render.filepath = output_path
|
628 |
+
if output_path.endswith('.mp4'):
|
629 |
+
# render a mp4 video
|
630 |
+
bpy.context.scene.render.image_settings.file_format = 'FFMPEG'
|
631 |
+
bpy.context.scene.render.ffmpeg.format = 'MPEG4'
|
632 |
+
bpy.context.scene.render.ffmpeg.codec = 'H264'
|
633 |
+
|
634 |
+
bpy.ops.render.render(animation=True, write_still=True)
|
635 |
+
|
636 |
+
|
637 |
+
def render_animspin(output_path, co, armature, arti, shadow=True, shading=True, quick=False):
|
638 |
+
# enter pose mode
|
639 |
+
print('Arti shape', arti.shape)
|
640 |
+
bpy.ops.object.mode_set(mode='POSE')
|
641 |
+
print('total bones', len(armature.pose.bones))
|
642 |
+
for i, pose_bone in enumerate(armature.pose.bones):
|
643 |
+
pose_bone.rotation_mode = 'XYZ'
|
644 |
+
pose_bone.keyframe_insert(data_path="rotation_euler", frame=0)
|
645 |
+
|
646 |
+
pose_bone.rotation_euler = arti[i]
|
647 |
+
pose_bone.keyframe_insert(data_path="rotation_euler", frame=30)
|
648 |
+
|
649 |
+
pose_bone.rotation_euler = Vector((0, 0, 0))
|
650 |
+
pose_bone.keyframe_insert(data_path="rotation_euler", frame=60)
|
651 |
+
|
652 |
+
pose_bone.rotation_euler = arti[i]
|
653 |
+
pose_bone.keyframe_insert(data_path="rotation_euler", frame=90)
|
654 |
+
pose_bone.keyframe_insert(data_path="rotation_euler", frame=150)
|
655 |
+
|
656 |
+
pose_bone.rotation_euler = Vector((0, 0, 0))
|
657 |
+
pose_bone.keyframe_insert(data_path="rotation_euler", frame=180)
|
658 |
+
bpy.ops.object.mode_set(mode='OBJECT')
|
659 |
+
|
660 |
+
# create a new coordinate system at the origin
|
661 |
+
new_co = add_co(location=(0, 0, 0), rotation=(0, 0, 0), scale=(1, 1, 1))
|
662 |
+
# set the object to be the child of the new coordinate system
|
663 |
+
co.parent = new_co
|
664 |
+
|
665 |
+
# add spin animation to the new coordinate system
|
666 |
+
new_co.rotation_mode = 'XYZ'
|
667 |
+
new_co.rotation_euler = (0, 0, 0)
|
668 |
+
new_co.keyframe_insert(data_path='rotation_euler', index=2, frame=90)
|
669 |
+
new_co.rotation_euler = (0, 0, 2*np.pi)
|
670 |
+
new_co.keyframe_insert(data_path='rotation_euler', index=2, frame=150)
|
671 |
+
|
672 |
+
if shadow:
|
673 |
+
add_floor()
|
674 |
+
|
675 |
+
if shading:
|
676 |
+
# create a sun light
|
677 |
+
bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
|
678 |
+
light = bpy.context.object
|
679 |
+
light.data.energy = 5
|
680 |
+
# angle pointing to the origin
|
681 |
+
light.rotation_euler = (0.1*np.pi, 0, 0)
|
682 |
+
# set angle
|
683 |
+
light.data.angle = 0.08*np.pi
|
684 |
+
|
685 |
+
else:
|
686 |
+
# global illumination by create world light
|
687 |
+
world = bpy.data.worlds.new('World')
|
688 |
+
bpy.context.scene.world = world
|
689 |
+
world.use_nodes = True
|
690 |
+
world_light = world.node_tree.nodes['Background']
|
691 |
+
world_light.inputs['Strength'].default_value = 1
|
692 |
+
world_light.inputs['Color'].default_value = (1, 1, 1, 1)
|
693 |
+
|
694 |
+
# create a camera
|
695 |
+
cam = bpy.data.cameras.new("Camera")
|
696 |
+
cam_ob = bpy.data.objects.new("Camera", cam)
|
697 |
+
camera = bpy.data.objects['Camera']
|
698 |
+
bpy.context.scene.collection.objects.link(camera)
|
699 |
+
camera.location = Vector((2, -1.5, 2))
|
700 |
+
look_at = Vector((0, 0, 0.36))
|
701 |
+
# compute the rotation
|
702 |
+
camera.rotation_mode = 'QUATERNION'
|
703 |
+
camera.rotation_quaternion = (camera.location - look_at).to_track_quat('Z', 'Y')
|
704 |
+
# set size
|
705 |
+
camera.data.sensor_width = 26
|
706 |
+
# set the camera to be active
|
707 |
+
bpy.context.scene.camera = camera
|
708 |
+
|
709 |
+
|
710 |
+
# render the animation
|
711 |
+
bpy.context.scene.frame_start = 0
|
712 |
+
bpy.context.scene.frame_end = 180
|
713 |
+
|
714 |
+
# make the rendered image square
|
715 |
+
bpy.context.scene.render.resolution_x = 1024
|
716 |
+
bpy.context.scene.render.resolution_y = 1024
|
717 |
+
|
718 |
+
setup_render()
|
719 |
+
|
720 |
+
if quick:
|
721 |
+
# reduce the number of samples
|
722 |
+
bpy.context.scene.cycles.samples = 128
|
723 |
+
bpy.context.scene.cycles.preview_samples = 128
|
724 |
+
bpy.context.scene.cycles.max_bounces = 1
|
725 |
+
bpy.context.scene.cycles.min_bounces = 1
|
726 |
+
bpy.context.scene.cycles.diffuse_bounces = 1
|
727 |
+
bpy.context.scene.cycles.glossy_bounces = 1
|
728 |
+
else:
|
729 |
+
bpy.context.scene.cycles.samples = 512
|
730 |
+
bpy.context.scene.cycles.preview_samples = 512
|
731 |
+
bpy.context.scene.cycles.max_bounces = 4
|
732 |
+
bpy.context.scene.cycles.min_bounces = 4
|
733 |
+
bpy.context.scene.cycles.diffuse_bounces = 4
|
734 |
+
bpy.context.scene.cycles.glossy_bounces = 4
|
735 |
+
|
736 |
+
# output path
|
737 |
+
bpy.context.scene.render.filepath = output_path
|
738 |
+
if output_path.endswith('.mp4'):
|
739 |
+
# render a mp4 video
|
740 |
+
bpy.context.scene.render.image_settings.file_format = 'FFMPEG'
|
741 |
+
bpy.context.scene.render.ffmpeg.format = 'MPEG4'
|
742 |
+
bpy.context.scene.render.ffmpeg.codec = 'H264'
|
743 |
+
|
744 |
+
bpy.ops.render.render(animation=True, write_still=True)
|
745 |
+
|
746 |
+
def render_scene(output_path, shadow=True):
|
747 |
+
|
748 |
+
if shadow:
|
749 |
+
add_floor()
|
750 |
+
|
751 |
+
|
752 |
+
# create a sun light
|
753 |
+
bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
|
754 |
+
light = bpy.context.object
|
755 |
+
light.data.energy = 5
|
756 |
+
# angle pointing to the origin
|
757 |
+
light.rotation_euler = (50/180*np.pi, 0, -20/180*np.pi)
|
758 |
+
# set angle
|
759 |
+
light.data.angle = 12/180*np.pi
|
760 |
+
|
761 |
+
# create a camera
|
762 |
+
cam = bpy.data.cameras.new("Camera")
|
763 |
+
cam_ob = bpy.data.objects.new("Camera", cam)
|
764 |
+
camera = bpy.data.objects['Camera']
|
765 |
+
bpy.context.scene.collection.objects.link(camera)
|
766 |
+
camera.location = Vector((0, -10, 5))
|
767 |
+
camera.rotation_euler = Vector((1.22, 0, 0))
|
768 |
+
# set size
|
769 |
+
camera.data.sensor_width = 26
|
770 |
+
# set the camera to be active
|
771 |
+
bpy.context.scene.camera = camera
|
772 |
+
|
773 |
+
|
774 |
+
|
775 |
+
# make the rendered image square
|
776 |
+
bpy.context.scene.render.resolution_x = 1920
|
777 |
+
bpy.context.scene.render.resolution_y = 1080
|
778 |
+
|
779 |
+
setup_render()
|
780 |
+
|
781 |
+
|
782 |
+
|
783 |
+
# output path
|
784 |
+
# output_path = '/home/ydengbd/objaverse/test.png'
|
785 |
+
bpy.context.scene.render.filepath = output_path
|
786 |
+
bpy.ops.render.render(write_still=True)
|
787 |
+
|
788 |
+
|
789 |
+
def render_teaser(output_path, shadow=True, quick=False):
|
790 |
+
|
791 |
+
if shadow:
|
792 |
+
add_floor(back=True)
|
793 |
+
|
794 |
+
# create a sun light
|
795 |
+
bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
|
796 |
+
light = bpy.context.object
|
797 |
+
light.data.energy = 5
|
798 |
+
# angle pointing to the origin
|
799 |
+
light.rotation_euler = (50/180*np.pi, 0, -20/180*np.pi)
|
800 |
+
# set angle
|
801 |
+
light.data.angle = 12/180*np.pi
|
802 |
+
|
803 |
+
# create a camera
|
804 |
+
cam = bpy.data.cameras.new("Camera")
|
805 |
+
cam_ob = bpy.data.objects.new("Camera", cam)
|
806 |
+
camera = bpy.data.objects['Camera']
|
807 |
+
bpy.context.scene.collection.objects.link(camera)
|
808 |
+
camera.location = Vector((0, -3, 1.3))
|
809 |
+
camera.rotation_euler = Vector((80/180*np.pi, 0, 0))
|
810 |
+
# set size
|
811 |
+
camera.data.sensor_width = 48
|
812 |
+
# set the camera to be active
|
813 |
+
bpy.context.scene.camera = camera
|
814 |
+
|
815 |
+
# render the animation
|
816 |
+
bpy.context.scene.frame_start = 0
|
817 |
+
bpy.context.scene.frame_end = 60
|
818 |
+
|
819 |
+
# make the rendered image square
|
820 |
+
bpy.context.scene.render.resolution_x = 2400
|
821 |
+
bpy.context.scene.render.resolution_y = 1080
|
822 |
+
|
823 |
+
setup_render()
|
824 |
+
|
825 |
+
if quick:
|
826 |
+
# reduce the number of samples
|
827 |
+
bpy.context.scene.cycles.samples = 128
|
828 |
+
bpy.context.scene.cycles.preview_samples = 128
|
829 |
+
bpy.context.scene.cycles.max_bounces = 1
|
830 |
+
bpy.context.scene.cycles.min_bounces = 1
|
831 |
+
bpy.context.scene.cycles.diffuse_bounces = 1
|
832 |
+
bpy.context.scene.cycles.glossy_bounces = 1
|
833 |
+
else:
|
834 |
+
bpy.context.scene.cycles.samples = 1024
|
835 |
+
bpy.context.scene.cycles.preview_samples = 1024
|
836 |
+
bpy.context.scene.cycles.max_bounces = 4
|
837 |
+
bpy.context.scene.cycles.min_bounces = 4
|
838 |
+
bpy.context.scene.cycles.diffuse_bounces = 4
|
839 |
+
bpy.context.scene.cycles.glossy_bounces = 4
|
840 |
+
|
841 |
+
# output path
|
842 |
+
bpy.context.scene.render.filepath = output_path
|
843 |
+
if output_path.endswith('.mp4'):
|
844 |
+
# render a mp4 video
|
845 |
+
bpy.context.scene.render.image_settings.file_format = 'FFMPEG'
|
846 |
+
bpy.context.scene.render.ffmpeg.format = 'MPEG4'
|
847 |
+
bpy.context.scene.render.ffmpeg.codec = 'H264'
|
848 |
+
|
849 |
+
bpy.ops.render.render(animation=True, write_still=True)
|
850 |
+
|
851 |
+
def setup_armature(path, tex=False, save=True):
|
852 |
+
joints_matrix = torch.load(os.path.join(path, 'joints.pt'))
|
853 |
+
connectivity = torch.load(os.path.join(path, 'conns.pt'))
|
854 |
+
skinning_weights = torch.load(os.path.join(path, 'skins.pt'))
|
855 |
+
obj_file_path = os.path.join(path, 'object.obj')
|
856 |
+
|
857 |
+
# bpy.ops.wm.obj_import(filepath=obj_file_path)
|
858 |
+
add_mesh(obj_file_path, tex=tex)
|
859 |
+
mesh_object = bpy.context.selected_objects[0]
|
860 |
+
|
861 |
+
# pack textures
|
862 |
+
bpy.ops.file.pack_all()
|
863 |
+
|
864 |
+
temp = torch.tensor(joints_matrix)[:, 1].clone()
|
865 |
+
joints_matrix[:, 1] = -joints_matrix[:, 2]
|
866 |
+
joints_matrix[:, 2] = temp
|
867 |
+
|
868 |
+
bpy.ops.object.armature_add()
|
869 |
+
armature_obj = bpy.context.object
|
870 |
+
|
871 |
+
|
872 |
+
bpy.ops.object.mode_set(mode='EDIT')
|
873 |
+
bpy.ops.armature.select_all(action='SELECT')
|
874 |
+
bpy.ops.armature.delete()
|
875 |
+
|
876 |
+
world_matrix = Matrix([[1, 0, 0, 0],
|
877 |
+
[0, 1, 0, 0],
|
878 |
+
[0, 0, 1, 0],
|
879 |
+
[0, 0, 0, 1]])
|
880 |
+
armature_obj.matrix_world = world_matrix
|
881 |
+
|
882 |
+
bone_dict = {}
|
883 |
+
|
884 |
+
i_name = 0
|
885 |
+
|
886 |
+
for i in range(len(joints_matrix)):
|
887 |
+
|
888 |
+
if connectivity[i] == i:
|
889 |
+
continue
|
890 |
+
bone_name = str(i_name)
|
891 |
+
bone = armature_obj.data.edit_bones.new(bone_name)
|
892 |
+
bone.head = joints_matrix[connectivity[i]].cpu().numpy()
|
893 |
+
bone.tail = joints_matrix[i].cpu().numpy()
|
894 |
+
bone_dict[bone_name] = bone
|
895 |
+
i_name += 1
|
896 |
+
|
897 |
+
for bone_name, bone in bone_dict.items():
|
898 |
+
# Find parent bone by checking if current bone's head matches any other bone's tail
|
899 |
+
for other_bone_name, other_bone in bone_dict.items():
|
900 |
+
if other_bone != bone and bone.head == other_bone.tail:
|
901 |
+
bone.parent = other_bone
|
902 |
+
break
|
903 |
+
|
904 |
+
assert i_name == skinning_weights.shape[1]
|
905 |
+
|
906 |
+
for i, skinning_weight in enumerate(skinning_weights):
|
907 |
+
# print("skinning_weight", skinning_weight)
|
908 |
+
vertex_index = i
|
909 |
+
for j,weight in enumerate(skinning_weight):
|
910 |
+
bone_name = str(j)
|
911 |
+
bone_weight = float(weight)
|
912 |
+
|
913 |
+
vertex_group_name = f"{bone_name}"
|
914 |
+
vertex_group = mesh_object.vertex_groups.get(vertex_group_name)
|
915 |
+
if vertex_group is None:
|
916 |
+
vertex_group = mesh_object.vertex_groups.new(name=vertex_group_name)
|
917 |
+
vertex_group.add([vertex_index], bone_weight, 'ADD')
|
918 |
+
|
919 |
+
# for obj in bpy.context.scene.objects:
|
920 |
+
# if obj.type == 'MESH':
|
921 |
+
modifier = mesh_object.modifiers.new(name="Armature", type='ARMATURE')
|
922 |
+
modifier.object = armature_obj
|
923 |
+
modifier.use_vertex_groups = True
|
924 |
+
print("Armature modifier added to mesh:", mesh_object.name)
|
925 |
+
|
926 |
+
bpy.ops.object.mode_set(mode='OBJECT')
|
927 |
+
if save:
|
928 |
+
bpy.ops.wm.save_as_mainfile(filepath= os.path.join(path, 'blender_output.blend'))
|
929 |
+
|
930 |
+
return armature_obj
|
931 |
+
|
932 |
+
def reload_tensor_skinning(data, bone_name_list):
|
933 |
+
|
934 |
+
# with open(json_file, "r") as f:
|
935 |
+
# skinning_data = json.load(f)
|
936 |
+
|
937 |
+
armature_obj = bpy.data.objects.get("Armature")
|
938 |
+
if not armature_obj:
|
939 |
+
print("Error: Armature object 'Armature' not found.")
|
940 |
+
return
|
941 |
+
|
942 |
+
# 将所有网格对象放置在骨骼对象的子集中
|
943 |
+
count = 0
|
944 |
+
for obj in bpy.context.scene.objects:
|
945 |
+
if obj.type == 'MESH':
|
946 |
+
obj.parent = armature_obj
|
947 |
+
count += 1
|
948 |
+
|
949 |
+
print("total mesh count:", count)
|
950 |
+
|
951 |
+
for obj in bpy.context.scene.objects:
|
952 |
+
vertex_index = 0
|
953 |
+
if obj.type == 'MESH':
|
954 |
+
# mesh_name = obj.name
|
955 |
+
# if mesh_name in skinning_data:
|
956 |
+
# skinning_info = skinning_data[mesh_name]
|
957 |
+
# if "weight" in skinning_info:
|
958 |
+
# print("Applying skinning data for mesh:", mesh_name)
|
959 |
+
# vertex_index = 0
|
960 |
+
# for vertex_weight in skinning_info["weight"]:
|
961 |
+
# for bone_name, weight_value in vertex_weight.items():
|
962 |
+
# vertex_group = obj.vertex_groups.get(bone_name)
|
963 |
+
# if vertex_group is None:
|
964 |
+
# vertex_group = obj.vertex_groups.new(name=bone_name)
|
965 |
+
# print("Vertex group created:", bone_name)
|
966 |
+
# vertex_group.add([vertex_index], weight_value, 'REPLACE')
|
967 |
+
# vertex_index += 1
|
968 |
+
# else:
|
969 |
+
# print("No skinning data found for mesh:", mesh_name)
|
970 |
+
|
971 |
+
for i, v in enumerate(obj.data.vertices):
|
972 |
+
v_co = np.array(v.co)
|
973 |
+
pc = data['pc'][:, :3].numpy()
|
974 |
+
y_max = pc[:, 1].max()
|
975 |
+
pc = pc + np.array([0, y_max, 0])
|
976 |
+
pc = pc / 2
|
977 |
+
dist = np.linalg.norm(pc - v_co, axis=1)
|
978 |
+
# min_idx = np.argmin(dist)
|
979 |
+
# sort, and then get top 3 index
|
980 |
+
min_idx_list = np.argsort(dist)[:3]
|
981 |
+
|
982 |
+
for min_idx in min_idx_list:
|
983 |
+
# get inverse distance weight
|
984 |
+
interpolate_weight = np.square(1 / dist[min_idx]) / np.square(1 / dist[min_idx_list]).sum()
|
985 |
+
|
986 |
+
for idx, j in enumerate(data['skins_index'][min_idx]):
|
987 |
+
if j == -1:
|
988 |
+
break
|
989 |
+
bone_name = bone_name_list[j]
|
990 |
+
vertex_group = obj.vertex_groups.get(str(int(bone_name)))
|
991 |
+
if vertex_group is None:
|
992 |
+
vertex_group = obj.vertex_groups.new(name=str(int(bone_name)))
|
993 |
+
print("Vertex group created:", bone_name)
|
994 |
+
|
995 |
+
vertex_group.add([i], interpolate_weight * data['skins_weight'][min_idx][idx], 'ADD')
|
996 |
+
|
997 |
+
|
998 |
+
for obj in bpy.context.scene.objects:
|
999 |
+
if obj.type == 'MESH':
|
1000 |
+
modifier = obj.modifiers.new(name="Armature", type='ARMATURE')
|
1001 |
+
modifier.object = armature_obj
|
1002 |
+
modifier.use_vertex_groups = True
|
1003 |
+
print("Armature modifier added to mesh:", obj.name)
|
1004 |
+
|
1005 |
+
def reload_tensor(data, root='data', save=True):
|
1006 |
+
joints_matrix = data['joints'].clone()
|
1007 |
+
connectivity = data['conns']
|
1008 |
+
obj_file_path = os.path.join(root, data['name'], 'object.obj')
|
1009 |
+
|
1010 |
+
# bpy.ops.wm.obj_import(filepath=obj_file_path)
|
1011 |
+
add_mesh(obj_file_path)
|
1012 |
+
mesh_object = bpy.context.selected_objects[0]
|
1013 |
+
|
1014 |
+
# pack textures
|
1015 |
+
bpy.ops.file.pack_all()
|
1016 |
+
|
1017 |
+
y_max = data['pc'][:, 1].max()
|
1018 |
+
joints_matrix = joints_matrix + torch.tensor([0, y_max, 0])
|
1019 |
+
joints_matrix = joints_matrix / 2
|
1020 |
+
|
1021 |
+
temp = joints_matrix[:, 1].clone()
|
1022 |
+
joints_matrix[:, 1] = -joints_matrix[:, 2]
|
1023 |
+
joints_matrix[:, 2] = temp
|
1024 |
+
|
1025 |
+
bpy.ops.object.armature_add()
|
1026 |
+
armature_obj = bpy.context.object
|
1027 |
+
|
1028 |
+
|
1029 |
+
bpy.ops.object.mode_set(mode='EDIT')
|
1030 |
+
bpy.ops.armature.select_all(action='SELECT')
|
1031 |
+
bpy.ops.armature.delete()
|
1032 |
+
|
1033 |
+
world_matrix = Matrix([[1, 0, 0, 0],
|
1034 |
+
[0, 1, 0, 0],
|
1035 |
+
[0, 0, 1, 0],
|
1036 |
+
[0, 0, 0, 1]])
|
1037 |
+
armature_obj.matrix_world = world_matrix
|
1038 |
+
|
1039 |
+
bone_dict = {}
|
1040 |
+
bone_name_list = np.zeros(data['bones_num'])
|
1041 |
+
i_name = 0
|
1042 |
+
|
1043 |
+
for i in range(len(joints_matrix)):
|
1044 |
+
|
1045 |
+
if connectivity[i] == i:
|
1046 |
+
continue
|
1047 |
+
bone_name = str(i_name)
|
1048 |
+
bone = armature_obj.data.edit_bones.new(bone_name)
|
1049 |
+
bone.head = joints_matrix[connectivity[i]].cpu().numpy()
|
1050 |
+
bone.tail = joints_matrix[i].cpu().numpy()
|
1051 |
+
bone_dict[bone_name] = bone
|
1052 |
+
for j, skinbone in enumerate(data['bones']):
|
1053 |
+
if torch.equal(skinbone[:3], data['joints'][connectivity[i]]) and torch.equal(skinbone[3:], data['joints'][i]):
|
1054 |
+
bone_name_list[j] = i_name
|
1055 |
+
i_name += 1
|
1056 |
+
|
1057 |
+
for bone_name, bone in bone_dict.items():
|
1058 |
+
# Find parent bone by checking if current bone's head matches any other bone's tail
|
1059 |
+
for other_bone_name, other_bone in bone_dict.items():
|
1060 |
+
if other_bone != bone and bone.head == other_bone.tail:
|
1061 |
+
bone.parent = other_bone
|
1062 |
+
break
|
1063 |
+
|
1064 |
+
print(bone_name_list)
|
1065 |
+
|
1066 |
+
reload_tensor_skinning(data, bone_name_list)
|
1067 |
+
|
1068 |
+
print("Armature modifier added to mesh:", mesh_object.name)
|
1069 |
+
|
1070 |
+
bpy.ops.object.mode_set(mode='OBJECT')
|
1071 |
+
if save:
|
1072 |
+
bpy.ops.wm.save_as_mainfile(filepath= os.path.join('/data2/ydengbd/Anymate/Anymate/data', data['name'], 'blender_output.blend'))
|
1073 |
+
|
1074 |
+
return armature_obj
|
1075 |
+
|
1076 |
+
def load_blender(blender_path):
|
1077 |
+
|
1078 |
+
bpy.ops.wm.read_homefile(use_empty=True)
|
1079 |
+
# bpy.ops.wm.append(directory=object_path, link=False)
|
1080 |
+
# load_object(object_path)
|
1081 |
+
bpy.ops.wm.open_mainfile(filepath=blender_path)
|
1082 |
+
armature_obj = []
|
1083 |
+
mesh_obj = []
|
1084 |
+
for obj in bpy.context.scene.objects:
|
1085 |
+
if obj.type == "ARMATURE":
|
1086 |
+
armature_obj.append(obj)
|
1087 |
+
if obj.type == "MESH":
|
1088 |
+
mesh_obj.append(obj)
|
1089 |
+
|
1090 |
+
print('mesh obj:', len(mesh_obj))
|
1091 |
+
|
1092 |
+
|
1093 |
+
|
1094 |
+
# start retrieve the information of mesh, skining and rigging
|
1095 |
+
|
1096 |
+
#1. retrieve the information of rigging, save the world matrix of the amature object
|
1097 |
+
total_armature_info = {}
|
1098 |
+
joints_matrix = []
|
1099 |
+
bone_dict = {}
|
1100 |
+
parent_name= []
|
1101 |
+
bone_count = 0
|
1102 |
+
for obj in armature_obj:
|
1103 |
+
# depsgraph = bpy.context.evaluated_depsgraph_get()
|
1104 |
+
# obj = obj.evaluated_get(depsgraph)
|
1105 |
+
armature_info = {}
|
1106 |
+
armature_info["world_matrix"] = [list(row) for row in obj.matrix_world.copy()]
|
1107 |
+
translation = obj.matrix_world.translation
|
1108 |
+
for bone in obj.pose.bones:
|
1109 |
+
|
1110 |
+
joints_matrix.append(np.array(list((obj.matrix_world.to_3x3() @ bone.head+translation).copy())))
|
1111 |
+
|
1112 |
+
if bone.parent:
|
1113 |
+
parent_name.append(bone.parent.name)
|
1114 |
+
else:
|
1115 |
+
parent_name.append('root')
|
1116 |
+
bone_dict[bone.name] = bone_count
|
1117 |
+
bone_count += 1
|
1118 |
+
connectivity = torch.zeros(bone_count, dtype=torch.int32)
|
1119 |
+
|
1120 |
+
for i, bone_name in enumerate(parent_name):
|
1121 |
+
if bone_name == 'root':
|
1122 |
+
connectivity[i] = i
|
1123 |
+
else:
|
1124 |
+
connectivity[i] = bone_dict[bone_name]
|
1125 |
+
joints_matrix = torch.from_numpy(np.array(joints_matrix))
|
1126 |
+
|
1127 |
+
skinning_weight = torch.zeros(len(mesh_obj[0].data.vertices), joints_matrix.shape[0])
|
1128 |
+
|
1129 |
+
vertex_index = 0
|
1130 |
+
for obj in mesh_obj:
|
1131 |
+
vertex_groups = obj.vertex_groups
|
1132 |
+
|
1133 |
+
|
1134 |
+
for vertex in obj.data.vertices:
|
1135 |
+
vertex_info = {}
|
1136 |
+
for group in vertex.groups:
|
1137 |
+
name = vertex_groups[group.group].name
|
1138 |
+
|
1139 |
+
weight = group.weight
|
1140 |
+
skinning_weight[vertex.index][bone_dict[name]] = weight
|
1141 |
+
|
1142 |
+
obj_save_path = blender_path.replace('.blend', '.obj')
|
1143 |
+
bpy.ops.wm.obj_export(filepath=obj_save_path, export_materials=False)
|
1144 |
+
return joints_matrix,connectivity, skinning_weight
|
1145 |
+
|
1146 |
+
|
1147 |
+
def save_scene(scene_path):
|
1148 |
+
# export the scene as a glb file
|
1149 |
+
if scene_path.endswith('.glb'):
|
1150 |
+
bpy.ops.export_scene.gltf(filepath=scene_path)
|
1151 |
+
bpy.ops.wm.save_as_mainfile(filepath=scene_path.replace('.glb', '.blend'))
|
1152 |
+
elif scene_path.endswith('.blend'):
|
1153 |
+
bpy.ops.wm.save_as_mainfile(filepath=scene_path)
|
1154 |
+
elif scene_path.endswith('.obj'):
|
1155 |
+
bpy.ops.wm.obj_export(filepath=scene_path, export_materials=False)
|
1156 |
+
else:
|
1157 |
+
raise ValueError(f"Unsupported file extension: {scene_path}")
|
1158 |
+
|
1159 |
+
if __name__ == '__main__':
|
1160 |
+
# load the mesh
|
1161 |
+
empty()
|
1162 |
+
add_mesh('/home/ydengbd/objaverse/obj/0001.obj')
|
1163 |
+
# load the joints
|
1164 |
+
joints_matrix = np.load('/home/ydengbd/objaverse/joints/0001.npy')
|
1165 |
+
add_joint(joints_matrix)
|
1166 |
+
# load the connections
|
1167 |
+
con_index = np.load('/home/ydengbd/objaverse/connections/0001.npy')
|
1168 |
+
add_conn(con_index)
|
1169 |
+
# load the skin
|
Anymate/utils/train_utils.py
ADDED
@@ -0,0 +1,406 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import numpy as np
|
3 |
+
from tqdm import tqdm
|
4 |
+
|
5 |
+
import torch
|
6 |
+
import torch.backends.cudnn as cudnn
|
7 |
+
from torch.utils.tensorboard import SummaryWriter
|
8 |
+
from torch.utils.data import DataLoader
|
9 |
+
|
10 |
+
from Anymate.dataset import AnymateDataset, my_collate
|
11 |
+
from Anymate.model import EncoderDecoder
|
12 |
+
from Anymate.utils.loss_utils import cross_entropy_with_probs_batch, cos_loss, cos_loss_clamp, chamfer_distance_with_average
|
13 |
+
from Anymate.utils.vol_utils import get_co, get_gt, extract_keypoints
|
14 |
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
15 |
+
from torch.distributed import init_process_group, destroy_process_group
|
16 |
+
from torch.utils.data.distributed import DistributedSampler
|
17 |
+
|
18 |
+
import point_cloud_utils as pcu
|
19 |
+
from sklearn.cluster import DBSCAN
|
20 |
+
from diffusers import DDPMScheduler, DDIMScheduler
|
21 |
+
import torch.nn.functional as F
|
22 |
+
from Anymate.utils.diffusion_utils import my_collate_diff, randn_tensor
|
23 |
+
|
24 |
+
|
25 |
+
def ddp_setup(rank: int, world_size: int, port: int):
|
26 |
+
"""
|
27 |
+
Args:
|
28 |
+
rank: Unique identifier of each process
|
29 |
+
world_size: Total number of processes
|
30 |
+
"""
|
31 |
+
os.environ["MASTER_ADDR"] = "localhost"
|
32 |
+
os.environ["MASTER_PORT"] = str(port)
|
33 |
+
torch.cuda.set_device(rank)
|
34 |
+
init_process_group(backend="nccl", rank=rank, world_size=world_size)
|
35 |
+
|
36 |
+
class AverageMeter(object):
|
37 |
+
"""Computes and stores the average and current value"""
|
38 |
+
def __init__(self):
|
39 |
+
self.reset()
|
40 |
+
|
41 |
+
def reset(self):
|
42 |
+
self.val = 0.0
|
43 |
+
self.avg = 0.0
|
44 |
+
self.sum = 0.0
|
45 |
+
self.count = 0.0
|
46 |
+
|
47 |
+
def update(self, val, n=1):
|
48 |
+
self.val = val
|
49 |
+
self.sum += val * n
|
50 |
+
self.count += n
|
51 |
+
self.avg = self.sum / self.count
|
52 |
+
|
53 |
+
def accumulate(self, val, n=1):
|
54 |
+
self.val = val
|
55 |
+
self.sum += val
|
56 |
+
self.count += n
|
57 |
+
self.avg = self.sum / self.count
|
58 |
+
|
59 |
+
def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='model_best.pth.tar', snapshot=None):
|
60 |
+
filepath = os.path.join(checkpoint, filename)
|
61 |
+
if is_best:
|
62 |
+
torch.save(state, filepath)
|
63 |
+
|
64 |
+
if snapshot and state['epoch'] % snapshot == 0:
|
65 |
+
torch.save(state, os.path.join(checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
|
66 |
+
|
67 |
+
def train_model(rank, world_size, config, args, shared_dict, port=12355):
|
68 |
+
ddp_setup(rank, world_size, port)
|
69 |
+
lowest_loss = 1e20
|
70 |
+
model_config = config['model']
|
71 |
+
model = EncoderDecoder(device=f'cuda:{rank}', dtype=torch.float32, **model_config)
|
72 |
+
model.to(f'cuda:{rank}')
|
73 |
+
|
74 |
+
if rank == 0:
|
75 |
+
print('only_embed', model.only_embed)
|
76 |
+
print('return_latents', model.return_latents)
|
77 |
+
print(model)
|
78 |
+
if not args.finetune:
|
79 |
+
model.encoder.requires_grad_(False)
|
80 |
+
model = DDP(model, device_ids=[rank])
|
81 |
+
optimizer_config = config['optimizer']
|
82 |
+
if args.finetune:
|
83 |
+
optimizer = torch.optim.Adam(model.module.parameters(), **optimizer_config)
|
84 |
+
else:
|
85 |
+
if args.encoder == 'miche':
|
86 |
+
optimizer = torch.optim.Adam(model.module.decoder.parameters(), **optimizer_config)
|
87 |
+
elif args.encoder == 'bert':
|
88 |
+
optimizer = torch.optim.Adam(list(model.module.decoder.parameters()) + list(model.module.point_proj.parameters()), **optimizer_config)
|
89 |
+
# optionally resume from a checkpoint
|
90 |
+
if args.resume:
|
91 |
+
try:
|
92 |
+
print("=> loading checkpoint '{}'".format(args.resume))
|
93 |
+
checkpoint = torch.load(args.resume)
|
94 |
+
args.start_epoch = checkpoint['epoch']
|
95 |
+
lowest_loss = checkpoint['lowest_loss']
|
96 |
+
model.module.load_state_dict(checkpoint['state_dict'], strict=True)
|
97 |
+
|
98 |
+
print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
|
99 |
+
except:
|
100 |
+
print("=> no checkpoint found at '{}'".format(args.resume))
|
101 |
+
|
102 |
+
cudnn.benchmark = True
|
103 |
+
print(' Total params: %.2fM' % (sum(p.numel() for p in optimizer.param_groups[0]['params']) / 1000000.0))
|
104 |
+
my_collate_func = my_collate_diff if args.mode == 'diffusion' else my_collate
|
105 |
+
if world_size > 1:
|
106 |
+
if not args.split:
|
107 |
+
train_dataset = shared_dict['train_dataset']
|
108 |
+
train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
|
109 |
+
train_loader = DataLoader(train_dataset, batch_size=args.train_batch, sampler=train_sampler, collate_fn= my_collate_func)
|
110 |
+
else:
|
111 |
+
train_dataset = AnymateDataset(name=args.trainset + f'_{rank}', root=args.root) #should changed to dpp version
|
112 |
+
train_loader = DataLoader(train_dataset, batch_size=args.train_batch, shuffle=True, collate_fn= my_collate_func)
|
113 |
+
else:
|
114 |
+
train_dataset = AnymateDataset(name=args.trainset, root=args.root)
|
115 |
+
train_loader = DataLoader(train_dataset, batch_size=args.train_batch, shuffle=True, collate_fn= my_collate_func)
|
116 |
+
|
117 |
+
if rank == 0:
|
118 |
+
test_loader = DataLoader(AnymateDataset(name=args.testset, root=args.root), batch_size=args.test_batch, shuffle=False, collate_fn= my_collate_func )
|
119 |
+
|
120 |
+
if not args.schedule:
|
121 |
+
args.schedule = [args.epochs//2]
|
122 |
+
scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.schedule, gamma=args.gamma)
|
123 |
+
# step the scheduler to the start epoch
|
124 |
+
for _ in range(args.start_epoch):
|
125 |
+
scheduler.step()
|
126 |
+
if rank == 0:
|
127 |
+
logger = SummaryWriter(log_dir=args.logdir)
|
128 |
+
print('start ')
|
129 |
+
print('test_frequency', args.test_freq)
|
130 |
+
print('start from epoch', args.start_epoch)
|
131 |
+
# start training
|
132 |
+
for epoch in range(args.start_epoch, args.epochs):
|
133 |
+
test_dict = None
|
134 |
+
is_best = False
|
135 |
+
lr = scheduler.get_last_lr()
|
136 |
+
if rank == 0:
|
137 |
+
print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr[0]))
|
138 |
+
train_loss, grad_norm = train(train_loader, model, optimizer, args)
|
139 |
+
if rank == 0 and (epoch == 0 or (epoch+1)%args.test_freq== 0):
|
140 |
+
print('Testing epoch', epoch+1)
|
141 |
+
test_dict = test(test_loader, model, args, world_size=world_size)
|
142 |
+
|
143 |
+
|
144 |
+
scheduler.step()
|
145 |
+
if rank == 0:
|
146 |
+
print('Epoch{:d}. train_loss: {:.6f}.'.format(epoch + 1, train_loss))
|
147 |
+
print('Epoch{:d}. grad_norm: {:.6f}.'.format(epoch + 1, grad_norm))
|
148 |
+
info = {'train_loss': train_loss, 'grad_norm': grad_norm, 'lr': lr[0]}
|
149 |
+
# print('Epoch{:d}. val_loss: {:.6f}.'.format(epoch + 1, val_loss))
|
150 |
+
if test_dict is not None:
|
151 |
+
for key, value in test_dict.items():
|
152 |
+
print('Epoch{:d}. {:s}: {:.6f}.'.format(epoch + 1, key, value))
|
153 |
+
|
154 |
+
test_loss = test_dict['test loss'] if not args.mode == 'diffusion' else test_dict['chamfer']
|
155 |
+
is_best = test_loss < lowest_loss
|
156 |
+
lowest_loss = min(test_loss, lowest_loss)
|
157 |
+
for key, value in test_dict.items():
|
158 |
+
info[key] = value
|
159 |
+
|
160 |
+
for tag, value in info.items():
|
161 |
+
logger.add_scalar(tag, value, epoch+1)
|
162 |
+
save_dict = {'epoch': epoch + 1, 'state_dict': model.module.state_dict(), 'lowest_loss': lowest_loss, 'optimizer': optimizer.state_dict(), 'model_config': model_config}
|
163 |
+
save_checkpoint(save_dict, is_best=is_best, checkpoint=args.checkpoint, snapshot=args.epochs//20)
|
164 |
+
|
165 |
+
def get_criterion(args):
|
166 |
+
if args.loss == 'cos':
|
167 |
+
criterion = cos_loss
|
168 |
+
elif args.loss == 'ce':
|
169 |
+
criterion = cross_entropy_with_probs_batch
|
170 |
+
elif args.loss == 'cos_clamp':
|
171 |
+
criterion = cos_loss_clamp
|
172 |
+
else:
|
173 |
+
criterion = chamfer_distance_with_average
|
174 |
+
return criterion
|
175 |
+
|
176 |
+
def get_train_loss(model, data, args):
|
177 |
+
criterion = get_criterion(args)
|
178 |
+
loss = 0.0
|
179 |
+
if args.mode == 'skin':
|
180 |
+
y_pred, idx = model(data, downsample=1024)
|
181 |
+
y_pred = torch.softmax(y_pred, dim=-1)
|
182 |
+
y = data['skins'].to(args.device)
|
183 |
+
y = y[:, idx]
|
184 |
+
loss = criterion(y_pred, y)
|
185 |
+
|
186 |
+
elif args.mode == 'conn':
|
187 |
+
y_pred = model(data, args.device)
|
188 |
+
y_pred = torch.softmax(y_pred, dim=-1)
|
189 |
+
y = data['conns'].to(args.device)
|
190 |
+
y = y[:, :y_pred.shape[1], :y_pred.shape[1]].float()
|
191 |
+
loss = criterion(y_pred, y)
|
192 |
+
|
193 |
+
elif args.mode == 'joints': # joints mode
|
194 |
+
if args.decoder == 'transformer_latent':
|
195 |
+
y_pred = model(data, args.device)
|
196 |
+
joints_gt = data['joints'].to(args.device)
|
197 |
+
loss = 0.0
|
198 |
+
for i in range(joints_gt.shape[0]):
|
199 |
+
joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
|
200 |
+
loss += criterion(y_pred[i:i+1], joints_gt_i.unsqueeze(0))
|
201 |
+
loss /= joints_gt.shape[0]
|
202 |
+
|
203 |
+
elif args.decoder == 'triplane' or args.decoder == 'implicit_transformer':
|
204 |
+
criterion = torch.nn.BCEWithLogitsLoss()
|
205 |
+
y_pred = model(data, args.device, downsample=True)
|
206 |
+
joints_gt = data['joints'].to(args.device)
|
207 |
+
for i in range(joints_gt.shape[0]):
|
208 |
+
joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
|
209 |
+
vol = get_co(data['vox'][i])
|
210 |
+
if data['vox'][i].shape[0] > 50000:
|
211 |
+
vol = vol[y_pred[i][1]]
|
212 |
+
gt = get_gt(vol.to(args.device), joints_gt_i)
|
213 |
+
loss += criterion(y_pred[i][0].squeeze(-1).unsqueeze(0), gt.unsqueeze(0))
|
214 |
+
else:
|
215 |
+
gt = get_gt(vol.to(args.device), joints_gt_i)
|
216 |
+
loss += criterion(y_pred[i].squeeze(-1).unsqueeze(0), gt.unsqueeze(0))
|
217 |
+
loss /= joints_gt.shape[0]
|
218 |
+
|
219 |
+
elif args.mode == 'diffusion':
|
220 |
+
noise_scheduler = DDIMScheduler(num_train_timesteps=args.num_train_step)
|
221 |
+
|
222 |
+
samples = data['joints_repeat'].to(model.device).float()
|
223 |
+
#use 256 input joints
|
224 |
+
samples = samples[...,:args.num_training_points,:]
|
225 |
+
|
226 |
+
samples = samples.to(model.device)
|
227 |
+
noise = torch.randn(samples.shape, device=samples.device)
|
228 |
+
assert samples.device == noise.device
|
229 |
+
bs = samples.shape[0]
|
230 |
+
|
231 |
+
# Sample a random timestep for each image
|
232 |
+
timesteps = torch.randint(
|
233 |
+
0, noise_scheduler.config.num_train_timesteps, (bs,), device=samples.device,
|
234 |
+
dtype=torch.int64
|
235 |
+
)
|
236 |
+
|
237 |
+
noisy_joints = noise_scheduler.add_noise(samples, noise, timesteps)
|
238 |
+
noisy_joints = noisy_joints.to(model.device)
|
239 |
+
noisy_joints = noisy_joints.permute(0, 2, 1)
|
240 |
+
|
241 |
+
noise_pred = model(data, noisy_joints=noisy_joints, timesteps = timesteps)
|
242 |
+
noise_pred = noise_pred.permute(0, 2, 1)
|
243 |
+
loss = F.mse_loss(noise_pred, noise)
|
244 |
+
|
245 |
+
return loss
|
246 |
+
|
247 |
+
def train(train_loader, model, optimizer, args):
|
248 |
+
if not args.finetune:
|
249 |
+
model.train()
|
250 |
+
model.module.encoder.eval()
|
251 |
+
else:
|
252 |
+
model.train()
|
253 |
+
loss_meter = AverageMeter()
|
254 |
+
grad_norm_meter = AverageMeter()
|
255 |
+
|
256 |
+
for data in tqdm(train_loader):
|
257 |
+
loss = get_train_loss(model, data, args)
|
258 |
+
optimizer.zero_grad()
|
259 |
+
loss.backward()
|
260 |
+
grad_norm = 0
|
261 |
+
|
262 |
+
for p in optimizer.param_groups[0]['params']:
|
263 |
+
grad_norm += p.grad.data.norm(2).item()
|
264 |
+
grad_norm_meter.update(grad_norm)
|
265 |
+
optimizer.step()
|
266 |
+
loss_meter.update(loss.item())
|
267 |
+
|
268 |
+
return loss_meter.avg, grad_norm_meter.avg
|
269 |
+
|
270 |
+
def test(test_loader, model, args, world_size=1):
|
271 |
+
model.eval()
|
272 |
+
assert args.mode in ['skin', 'joints', 'conn', 'diffusion'], 'mode should be choose from [skin, joints, conn, diffusion], got {}'.format(args.mode)
|
273 |
+
|
274 |
+
if args.mode == 'skin' or args.mode == 'conn':
|
275 |
+
loss_meter = AverageMeter()
|
276 |
+
cos_sim_meter = AverageMeter()
|
277 |
+
cos_clamp_meter = AverageMeter()
|
278 |
+
for i, data in enumerate(tqdm(test_loader)):
|
279 |
+
if world_size > 1 and i > 1000:
|
280 |
+
break
|
281 |
+
with torch.no_grad():
|
282 |
+
y_pred = model(data, args.device)
|
283 |
+
y_pred = torch.softmax(y_pred, dim=-1)
|
284 |
+
|
285 |
+
if args.mode == 'skin':
|
286 |
+
y = data['skins'].to(args.device)
|
287 |
+
elif args.mode == 'conn':
|
288 |
+
y = data['conns'].to(args.device)
|
289 |
+
y = y[:, :y_pred.shape[1], :y_pred.shape[1]].float()
|
290 |
+
|
291 |
+
loss = 0.0
|
292 |
+
loss = cross_entropy_with_probs_batch(y_pred, y)
|
293 |
+
loss_meter.update(loss.item())
|
294 |
+
cos_sim = cos_loss(y_pred, y)
|
295 |
+
cos_sim_meter.update(cos_sim.mean().item()) # 1 - loss.item()
|
296 |
+
cos_clamp = cos_loss_clamp(y_pred, y)
|
297 |
+
cos_clamp_meter.update(cos_clamp.mean().item())
|
298 |
+
|
299 |
+
loss_dict = {'test loss': loss_meter.avg, 'cos_sim': cos_sim_meter.avg, 'cos_clamp': cos_clamp_meter.avg}
|
300 |
+
# get the loss of the joints prediction
|
301 |
+
elif args.mode == 'joints':
|
302 |
+
if args.decoder == 'transformer_latent':
|
303 |
+
loss_meter = AverageMeter()
|
304 |
+
emd_meter = AverageMeter()
|
305 |
+
for i, data in tqdm(enumerate(test_loader)):
|
306 |
+
if world_size > 1 and i > 1000:
|
307 |
+
break
|
308 |
+
with torch.no_grad():
|
309 |
+
y_pred = model(data, args.device)
|
310 |
+
joints_gt = data['joints'].to(args.device)
|
311 |
+
|
312 |
+
loss = 0.0
|
313 |
+
emd = 0.0
|
314 |
+
for i in range(joints_gt.shape[0]):
|
315 |
+
joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
|
316 |
+
y_pred_i = y_pred[i]
|
317 |
+
|
318 |
+
y_pred_i = y_pred[i].detach().cpu().numpy()
|
319 |
+
clustering = DBSCAN(eps=0.03, min_samples=1).fit(y_pred_i) # Consider add eps and min_samples as arguments
|
320 |
+
cluster_centers = []
|
321 |
+
for cluster in set(clustering.labels_):
|
322 |
+
cluster_centers.append(y_pred_i[clustering.labels_ == cluster].mean(axis=0))
|
323 |
+
y_pred_i = torch.from_numpy(np.array(cluster_centers)).to(args.device)
|
324 |
+
|
325 |
+
if y_pred_i.shape[0] < 2:
|
326 |
+
print(data['name'][i] + ' has less than 2 points')
|
327 |
+
continue
|
328 |
+
loss += chamfer_distance_with_average(y_pred_i.unsqueeze(0), joints_gt_i.unsqueeze(0))
|
329 |
+
emd_i, pi = pcu.earth_movers_distance(y_pred_i.cpu().numpy().astype(np.float64), joints_gt_i.cpu().numpy().astype(np.float64))
|
330 |
+
emd += emd_i
|
331 |
+
if loss == 0 or emd == 0:
|
332 |
+
continue
|
333 |
+
loss /= joints_gt.shape[0]
|
334 |
+
loss_meter.update(loss.item())
|
335 |
+
emd_meter.update(emd)
|
336 |
+
loss_dict = {'test loss': loss_meter.avg, 'emd': emd_meter.avg}
|
337 |
+
|
338 |
+
elif args.decoder == 'triplane' or 'implicit_transformer':
|
339 |
+
loss_meter = AverageMeter()
|
340 |
+
emd_meter = AverageMeter()
|
341 |
+
chamfer_meter = AverageMeter()
|
342 |
+
criterion = torch.nn.BCEWithLogitsLoss()
|
343 |
+
for data in tqdm(test_loader):
|
344 |
+
with torch.no_grad():
|
345 |
+
y_pred = model(data, args.device)
|
346 |
+
joints_gt = data['joints'].to(args.device)
|
347 |
+
loss = 0.0
|
348 |
+
emd = 0.0
|
349 |
+
chamfer = 0.0
|
350 |
+
for i in range(joints_gt.shape[0]):
|
351 |
+
joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
|
352 |
+
vol = get_co(data['vox'][i])
|
353 |
+
gt = get_gt(vol.to(args.device), joints_gt_i)
|
354 |
+
loss += criterion(y_pred[i].squeeze(-1).unsqueeze(0), gt.unsqueeze(0))
|
355 |
+
key_points = extract_keypoints(y_pred[i].cpu(), data['vox'][i])
|
356 |
+
if len(key_points) < 2:
|
357 |
+
continue
|
358 |
+
key_points = key_points / 32 - 1
|
359 |
+
chamfer += chamfer_distance_with_average(torch.from_numpy(key_points).unsqueeze(0).to(joints_gt_i.device), joints_gt_i.unsqueeze(0))
|
360 |
+
emd_i, _ = pcu.earth_movers_distance(key_points.astype(np.float64), joints_gt_i.cpu().numpy().astype(np.float64))
|
361 |
+
emd += emd_i
|
362 |
+
if loss == 0 or emd == 0 or chamfer == 0:
|
363 |
+
continue
|
364 |
+
loss /= joints_gt.shape[0]
|
365 |
+
loss_meter.update(loss.item())
|
366 |
+
emd_meter.update(emd)
|
367 |
+
chamfer_meter.update(chamfer.item())
|
368 |
+
loss_dict = {'test loss': loss_meter.avg, 'emd': emd_meter.avg, 'chamfer': chamfer_meter.avg}
|
369 |
+
|
370 |
+
elif args.mode == 'diffusion':
|
371 |
+
loss_meter = AverageMeter()
|
372 |
+
emd_meter = AverageMeter()
|
373 |
+
chamfer_meter = AverageMeter()
|
374 |
+
generator=torch.Generator(device='cpu').manual_seed(args.seed+1)
|
375 |
+
scheduler = DDIMScheduler(num_train_timesteps=args.num_train_step)
|
376 |
+
scheduler.set_timesteps(args.num_train_step)
|
377 |
+
points_shape = [args.test_batch, args.num_training_points, 3]
|
378 |
+
for data in tqdm(test_loader):
|
379 |
+
joints_gt = data['joints'].to(dtype=torch.float64)
|
380 |
+
points_noise = randn_tensor(points_shape, generator=generator)
|
381 |
+
points = points_noise.permute(0, 2, 1).to(model.device)
|
382 |
+
for t in scheduler.timesteps:
|
383 |
+
with torch.no_grad():
|
384 |
+
time_steps = torch.ones(args.test_batch, 1, dtype=torch.long) * t
|
385 |
+
time_steps = time_steps.to(model.device)
|
386 |
+
model_output = model(data, noisy_joints=points, timesteps = time_steps)
|
387 |
+
|
388 |
+
points = scheduler.step(model_output, t, points, generator=generator).prev_sample
|
389 |
+
points = points.permute(0, 2, 1).cpu()
|
390 |
+
|
391 |
+
chamfer_sum = 0.0
|
392 |
+
emd_sum = 0.0
|
393 |
+
|
394 |
+
for i in range(args.test_batch):
|
395 |
+
joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
|
396 |
+
points_i = points[i]
|
397 |
+
points_i = points_i.reshape( -1, 3)
|
398 |
+
emd, p = pcu.earth_movers_distance(points_i.cpu().numpy(),joints_gt_i[:,:3].cpu().numpy())
|
399 |
+
emd_sum += emd
|
400 |
+
chamfer_sum += chamfer_distance_with_average(points_i.unsqueeze(0),joints_gt_i[:,:3].unsqueeze(0))
|
401 |
+
|
402 |
+
emd_meter.update(emd_sum)
|
403 |
+
chamfer_meter.update(chamfer_sum.item())
|
404 |
+
loss_dict = {'chamfer': chamfer_meter.avg, 'emd': emd_meter.avg}
|
405 |
+
|
406 |
+
return loss_dict
|
Anymate/utils/ui_utils.py
ADDED
@@ -0,0 +1,284 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import trimesh
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
import matplotlib.pyplot as plt
|
6 |
+
import gradio as gr
|
7 |
+
import time
|
8 |
+
bone_colors = plt.get_cmap('tab10')
|
9 |
+
|
10 |
+
from Anymate.utils.utils import load_checkpoint, get_joint, get_connectivity, get_skinning
|
11 |
+
from Anymate.utils.dataset_utils import obj2mesh
|
12 |
+
from Anymate.args import anymate_args
|
13 |
+
# from Anymate.utils.render_utils import empty, add_co, add_mesh, add_joint, add_conn, add_skin, setup_armature
|
14 |
+
|
15 |
+
def visualize_results(mesh_file=None, joints=None, conns=None, skins=None):
|
16 |
+
# Create a scene with both original and processed meshes
|
17 |
+
scene = trimesh.Scene()
|
18 |
+
vis_file = mesh_file.replace('object.obj', 'vis.glb')
|
19 |
+
|
20 |
+
if mesh_file is not None:
|
21 |
+
# Load the original mesh (in blue) with transparency
|
22 |
+
# original_mesh = trimesh.load(mesh_file)
|
23 |
+
original_mesh = obj2mesh(mesh_file)
|
24 |
+
if skins is not None:
|
25 |
+
# pdb.set_trace()
|
26 |
+
# Get per-vertex colors based on skinning weights
|
27 |
+
vertex_colors = np.zeros((len(original_mesh.vertices), 4))
|
28 |
+
|
29 |
+
# Convert skinning weights to numpy if needed
|
30 |
+
if isinstance(skins, torch.Tensor):
|
31 |
+
skins = skins.cpu().numpy()
|
32 |
+
|
33 |
+
# For each bone, blend colors based on skinning weights
|
34 |
+
for bone_idx in range(skins.shape[1]):
|
35 |
+
bone_color = np.array(bone_colors(bone_idx % 10)) # Get base color for this bone
|
36 |
+
weights = skins[:, bone_idx]
|
37 |
+
vertex_colors += np.outer(weights, bone_color) # Blend weighted colors
|
38 |
+
|
39 |
+
# Normalize and clip colors
|
40 |
+
vertex_colors = np.clip(vertex_colors, 0, 1)
|
41 |
+
|
42 |
+
# Convert to vertex colors and set alpha
|
43 |
+
vertex_colors = (vertex_colors * 255).astype(np.uint8)
|
44 |
+
vertex_colors[:, 3] = 255 # Set alpha to 100 for transparency
|
45 |
+
# print(vertex_colors.shape)
|
46 |
+
# print(vertex_colors.max(axis=0), vertex_colors.min(axis=0), vertex_colors.mean(axis=0))
|
47 |
+
|
48 |
+
# Apply colors directly to vertices
|
49 |
+
original_mesh.visual.vertex_colors = vertex_colors
|
50 |
+
|
51 |
+
# face_colors = np.zeros((len(original_mesh.faces), 4))
|
52 |
+
|
53 |
+
# processed_mesh = trimesh.load(mesh_file)
|
54 |
+
processed_mesh = obj2mesh(mesh_file)
|
55 |
+
# Assign vertex colors from original_mesh to processed_mesh
|
56 |
+
# Since they might have different number of vertices, we need to find closest vertices
|
57 |
+
|
58 |
+
# Get vertices from both meshes
|
59 |
+
orig_vertices = original_mesh.vertices
|
60 |
+
proc_vertices = processed_mesh.vertices
|
61 |
+
|
62 |
+
# For each vertex in processed_mesh, find the closest vertex in original_mesh
|
63 |
+
closest_indices = []
|
64 |
+
for proc_vertex in proc_vertices:
|
65 |
+
# Calculate distances to all original vertices
|
66 |
+
distances = np.linalg.norm(orig_vertices - proc_vertex, axis=1)
|
67 |
+
# Find index of closest vertex
|
68 |
+
closest_idx = np.argmin(distances)
|
69 |
+
closest_indices.append(closest_idx)
|
70 |
+
|
71 |
+
proc_vertex_colors = original_mesh.visual.vertex_colors[closest_indices]
|
72 |
+
processed_mesh.visual.vertex_colors = proc_vertex_colors
|
73 |
+
original_mesh = processed_mesh
|
74 |
+
|
75 |
+
else:
|
76 |
+
original_mesh.visual.face_colors = [255, 255, 255, 100] # Blue with alpha=100 for transparency
|
77 |
+
scene.add_geometry(original_mesh)
|
78 |
+
|
79 |
+
if joints is not None:
|
80 |
+
# create a sphere for each joint
|
81 |
+
for position in joints:
|
82 |
+
sphere = trimesh.primitives.Sphere(radius=0.02)
|
83 |
+
sphere.visual.face_colors = [255, 0, 0, 255] # Red with transparency
|
84 |
+
sphere.apply_translation(position.cpu().numpy())
|
85 |
+
scene.add_geometry(sphere)
|
86 |
+
|
87 |
+
if conns is not None:
|
88 |
+
# create a line for each connectivity
|
89 |
+
for i, conn in enumerate(conns):
|
90 |
+
if i == conn:
|
91 |
+
continue
|
92 |
+
# Create cylinder between joints
|
93 |
+
points = [joints[i].cpu().numpy(), joints[conn].cpu().numpy()]
|
94 |
+
direction = points[1] - points[0]
|
95 |
+
height = np.linalg.norm(direction)
|
96 |
+
cylinder = trimesh.primitives.Cylinder(radius=0.01, height=height)
|
97 |
+
|
98 |
+
# Calculate rotation matrix to align cylinder with direction
|
99 |
+
direction = direction / height # Normalize direction vector
|
100 |
+
up_vector = np.array([0, 0, 1])
|
101 |
+
rotation_matrix = trimesh.geometry.align_vectors(up_vector, direction)
|
102 |
+
|
103 |
+
# Apply rotation and translation to cylinder
|
104 |
+
cylinder.apply_transform(rotation_matrix)
|
105 |
+
cylinder.apply_translation(points[0] + direction * height/2)
|
106 |
+
|
107 |
+
cylinder.visual.face_colors = [0, 0, 255, 255] # Blue
|
108 |
+
scene.add_geometry(cylinder)
|
109 |
+
|
110 |
+
# Export the scene
|
111 |
+
scene.export(vis_file)
|
112 |
+
return vis_file
|
113 |
+
|
114 |
+
|
115 |
+
def process_mesh_to_pc(obj_path, sample_num = 8192, save_path = None):
|
116 |
+
# mesh_list : list of trimesh
|
117 |
+
try :
|
118 |
+
mesh = trimesh.load_mesh(obj_path)
|
119 |
+
|
120 |
+
points, face_idx = mesh.sample(sample_num, return_index=True)
|
121 |
+
normals = mesh.face_normals[face_idx]
|
122 |
+
|
123 |
+
pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16)
|
124 |
+
|
125 |
+
|
126 |
+
if save_path is not None:
|
127 |
+
np.save(save_path, pc_normal)
|
128 |
+
|
129 |
+
return pc_normal
|
130 |
+
except Exception as e:
|
131 |
+
print(f"Error: {obj_path} {e}")
|
132 |
+
return None
|
133 |
+
|
134 |
+
|
135 |
+
def normalize_mesh(mesh):
|
136 |
+
# Check if input is a scene with multiple meshes
|
137 |
+
if isinstance(mesh, trimesh.Scene):
|
138 |
+
# Combine all meshes in the scene into a single mesh
|
139 |
+
meshes = []
|
140 |
+
for geometry in mesh.geometry.values():
|
141 |
+
if isinstance(geometry, trimesh.Trimesh):
|
142 |
+
# Transform mesh to scene coordinates
|
143 |
+
transform = mesh.graph[mesh.graph.nodes_geometry[0]][0]
|
144 |
+
geometry.apply_transform(transform)
|
145 |
+
meshes.append(geometry)
|
146 |
+
|
147 |
+
# Combine all meshes
|
148 |
+
mesh = trimesh.util.concatenate(meshes)
|
149 |
+
|
150 |
+
# Get vertices and compute bounding box
|
151 |
+
vertices = mesh.vertices
|
152 |
+
bbox_min = vertices.min(axis=0)
|
153 |
+
bbox_max = vertices.max(axis=0)
|
154 |
+
|
155 |
+
# Find center and scale
|
156 |
+
center = (bbox_min + bbox_max) * 0.5
|
157 |
+
scale = 2.0 / (bbox_max - bbox_min).max()
|
158 |
+
|
159 |
+
# Center and scale vertices
|
160 |
+
vertices = (vertices - center) * scale
|
161 |
+
|
162 |
+
# Create new mesh with normalized vertices
|
163 |
+
normalized_mesh = trimesh.Trimesh(vertices=vertices,
|
164 |
+
faces=mesh.faces,
|
165 |
+
face_normals=mesh.face_normals,
|
166 |
+
vertex_normals=mesh.vertex_normals,
|
167 |
+
process=False)
|
168 |
+
|
169 |
+
# # Copy texture from original mesh if it exists
|
170 |
+
# if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'material'):
|
171 |
+
# print("copy material")
|
172 |
+
# normalized_mesh.visual.material = mesh.visual.material
|
173 |
+
# if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'texture'):
|
174 |
+
# print("copy texture")
|
175 |
+
# normalized_mesh.visual.texture = mesh.visual.texture
|
176 |
+
# if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'uv'):
|
177 |
+
# print("copy uv")
|
178 |
+
# normalized_mesh.visual.uv = mesh.visual.uv
|
179 |
+
|
180 |
+
return normalized_mesh
|
181 |
+
|
182 |
+
|
183 |
+
def vis_joint(normalized_mesh_file, joints):
|
184 |
+
if normalized_mesh_file is None or joints is None:
|
185 |
+
return None, None
|
186 |
+
vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints)
|
187 |
+
return vis_file, vis_file
|
188 |
+
|
189 |
+
def vis_connectivity(normalized_mesh_file, joints, conns):
|
190 |
+
if normalized_mesh_file is None or joints is None or conns is None:
|
191 |
+
return None, None
|
192 |
+
vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints, conns=conns)
|
193 |
+
return vis_file, vis_file
|
194 |
+
|
195 |
+
def vis_skinning(normalized_mesh_file, joints, conns, skins):
|
196 |
+
if normalized_mesh_file is None or joints is None or conns is None or skins is None:
|
197 |
+
return None, None
|
198 |
+
vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints, conns=conns, skins=skins)
|
199 |
+
return vis_file, vis_file
|
200 |
+
|
201 |
+
def prepare_blender_file(normalized_mesh_file):
|
202 |
+
if normalized_mesh_file is None:
|
203 |
+
return None
|
204 |
+
|
205 |
+
if not os.path.exists(normalized_mesh_file) or not os.path.exists(normalized_mesh_file.replace('object.obj', 'joints.pt')) or not os.path.exists(normalized_mesh_file.replace('object.obj', 'conns.pt')) or not os.path.exists(normalized_mesh_file.replace('object.obj', 'skins.pt')):
|
206 |
+
return None
|
207 |
+
|
208 |
+
folder = normalized_mesh_file.replace('object.obj', '')
|
209 |
+
abs_folder = os.path.abspath(folder)
|
210 |
+
os.system(f"python Render.py --path {abs_folder}")
|
211 |
+
|
212 |
+
blender_file = os.path.join(folder, 'blender_output.blend')
|
213 |
+
while not os.path.exists(blender_file):
|
214 |
+
time.sleep(1)
|
215 |
+
|
216 |
+
return blender_file
|
217 |
+
|
218 |
+
|
219 |
+
def process_input(mesh_file):
|
220 |
+
"""
|
221 |
+
Function to handle input changes and initialize visualization
|
222 |
+
|
223 |
+
Args:
|
224 |
+
mesh_file: Path to input mesh file
|
225 |
+
joint_checkpoint: Path to joint prediction checkpoint
|
226 |
+
conn_checkpoint: Path to connectivity prediction checkpoint
|
227 |
+
skin_checkpoint: Path to skinning prediction checkpoint
|
228 |
+
|
229 |
+
Returns:
|
230 |
+
vis_file: Path to visualization file
|
231 |
+
"""
|
232 |
+
|
233 |
+
# For now just visualize the input mesh
|
234 |
+
if mesh_file is None:
|
235 |
+
return None, None, None, None, None, None, None, None
|
236 |
+
|
237 |
+
# make folder for tmp files
|
238 |
+
os.makedirs(f"Anymate/tmp/{mesh_file.split('/')[-1].replace('.obj', '')}", exist_ok=True)
|
239 |
+
|
240 |
+
normalized_mesh = normalize_mesh(obj2mesh(mesh_file))
|
241 |
+
normalized_mesh_file = f"Anymate/tmp/{mesh_file.split('/')[-1].replace('.obj', '')}/object.obj"
|
242 |
+
normalized_mesh.export(normalized_mesh_file)
|
243 |
+
|
244 |
+
# normalized_mesh.export(mesh_file)
|
245 |
+
|
246 |
+
vis_file = visualize_results(mesh_file=normalized_mesh_file)
|
247 |
+
pc = process_mesh_to_pc(normalized_mesh_file)
|
248 |
+
pc = torch.from_numpy(pc).to(anymate_args.device).to(torch.float32)
|
249 |
+
|
250 |
+
# print(pc.shape, pc.max(dim=0), pc.min(dim=0))
|
251 |
+
|
252 |
+
return normalized_mesh_file, vis_file, vis_file, None, pc, None, None, None
|
253 |
+
|
254 |
+
|
255 |
+
def get_model(checkpoint):
|
256 |
+
model = load_checkpoint(checkpoint, anymate_args.device, anymate_args.num_joints)
|
257 |
+
return model, True
|
258 |
+
|
259 |
+
def get_result_joint(mesh_file, model, pc, eps=0.03, min_samples=1):
|
260 |
+
return get_joint(pc, model, device=anymate_args.device, save=mesh_file.replace('object.obj', 'joints.pt'), eps=eps, min_samples=min_samples)
|
261 |
+
|
262 |
+
def get_result_connectivity(mesh_file, model, pc, joints):
|
263 |
+
return get_connectivity(pc, joints, model, device=anymate_args.device, save=mesh_file.replace('object.obj', 'conns.pt'))
|
264 |
+
|
265 |
+
def get_result_skinning(mesh_file, model, pc, joints, conns):
|
266 |
+
# mesh = trimesh.load(mesh_file)
|
267 |
+
mesh = obj2mesh(mesh_file)
|
268 |
+
vertices = torch.from_numpy(mesh.vertices).to(anymate_args.device).to(torch.float32)
|
269 |
+
vertex_normals = torch.from_numpy(mesh.vertex_normals).to(anymate_args.device).to(torch.float32)
|
270 |
+
vertices = torch.cat([vertices, vertex_normals], dim=-1)
|
271 |
+
return get_skinning(pc, joints, conns, model, vertices=vertices, device=anymate_args.device, save=mesh_file.replace('object.obj', 'skins.pt'))
|
272 |
+
|
273 |
+
def get_all_models(checkpoint_joint, checkpoint_conn, checkpoint_skin):
|
274 |
+
model_joint = load_checkpoint(checkpoint_joint, anymate_args.device, anymate_args.num_joints)
|
275 |
+
model_connectivity = load_checkpoint(checkpoint_conn, anymate_args.device, anymate_args.num_joints)
|
276 |
+
model_skin = load_checkpoint(checkpoint_skin, anymate_args.device, anymate_args.num_joints)
|
277 |
+
return model_joint, model_connectivity, model_skin, True, True, True
|
278 |
+
|
279 |
+
def get_all_results(mesh_file, model_joint, model_connectivity, model_skin, pc, eps=0.03, min_samples=1):
|
280 |
+
joints = get_result_joint(mesh_file, model_joint, pc, eps=eps, min_samples=min_samples)
|
281 |
+
conns = get_result_connectivity(mesh_file, model_connectivity, pc, joints)
|
282 |
+
skins = get_result_skinning(mesh_file, model_skin, pc, joints, conns)
|
283 |
+
return joints, conns, skins
|
284 |
+
|
Anymate/utils/ui_utils_bpy.py
ADDED
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import trimesh
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
|
5 |
+
from Anymate.utils.utils import load_checkpoint, get_joints, get_connectivity
|
6 |
+
from Anymate.args import anymate_args
|
7 |
+
from Anymate.utils.render_utils import empty, add_co, add_mesh, add_joints, add_conn, add_skin, setup_armature, save_scene
|
8 |
+
|
9 |
+
def visualize_results(mesh_file=None, joints=None, connectivity=None, skinning=None):
|
10 |
+
|
11 |
+
import bpy
|
12 |
+
# Create a scene with both original and processed meshes
|
13 |
+
vis_file = "Anymate/tmp/vis_scene.glb"
|
14 |
+
print('fffffffff')
|
15 |
+
|
16 |
+
# empty()
|
17 |
+
bpy.ops.wm.read_homefile(use_empty=True)
|
18 |
+
|
19 |
+
if mesh_file is not None:
|
20 |
+
# add_mesh(mesh_file)
|
21 |
+
bpy.ops.wm.obj_import(filepath=mesh_file)
|
22 |
+
|
23 |
+
if joints is not None:
|
24 |
+
add_joints(joints)
|
25 |
+
|
26 |
+
if connectivity is not None:
|
27 |
+
add_conn(connectivity, joints)
|
28 |
+
|
29 |
+
if skinning is not None:
|
30 |
+
add_skin(mesh_file, skinning)
|
31 |
+
|
32 |
+
# setup_armature()
|
33 |
+
# save_scene(vis_file)
|
34 |
+
bpy.ops.wm.save_as_mainfile(filepath=vis_file)
|
35 |
+
return vis_file
|
36 |
+
|
37 |
+
|
38 |
+
def process_mesh_to_pc(obj_path, sample_num = 8192, save_path = None):
|
39 |
+
# mesh_list : list of trimesh
|
40 |
+
try :
|
41 |
+
mesh = trimesh.load_mesh(obj_path)
|
42 |
+
|
43 |
+
points, face_idx = mesh.sample(sample_num, return_index=True)
|
44 |
+
normals = mesh.face_normals[face_idx]
|
45 |
+
|
46 |
+
pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16)
|
47 |
+
|
48 |
+
|
49 |
+
if save_path is not None:
|
50 |
+
np.save(save_path, pc_normal)
|
51 |
+
|
52 |
+
return pc_normal
|
53 |
+
except Exception as e:
|
54 |
+
print(f"Error: {obj_path} {e}")
|
55 |
+
return None
|
56 |
+
|
57 |
+
|
58 |
+
def normalize_mesh(mesh):
|
59 |
+
# Get vertices and compute bounding box
|
60 |
+
vertices = mesh.vertices
|
61 |
+
bbox_min = vertices.min(axis=0)
|
62 |
+
bbox_max = vertices.max(axis=0)
|
63 |
+
|
64 |
+
# Find center and scale
|
65 |
+
center = (bbox_min + bbox_max) * 0.5
|
66 |
+
scale = 2.0 / (bbox_max - bbox_min).max()
|
67 |
+
|
68 |
+
# Center and scale vertices
|
69 |
+
vertices = (vertices - center) * scale
|
70 |
+
|
71 |
+
# Create new mesh with normalized vertices
|
72 |
+
normalized_mesh = trimesh.Trimesh(vertices=vertices,
|
73 |
+
faces=mesh.faces,
|
74 |
+
face_normals=mesh.face_normals,
|
75 |
+
vertex_normals=mesh.vertex_normals)
|
76 |
+
|
77 |
+
return normalized_mesh
|
78 |
+
|
79 |
+
|
80 |
+
def vis_joint(normalized_mesh_file, joints):
|
81 |
+
vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints)
|
82 |
+
return vis_file
|
83 |
+
|
84 |
+
def vis_connectivity(normalized_mesh_file, joints, connectivity):
|
85 |
+
vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints, connectivity=connectivity)
|
86 |
+
return vis_file
|
87 |
+
|
88 |
+
def vis_skinning(skinning):
|
89 |
+
vis_file = visualize_results(skinning=skinning)
|
90 |
+
return vis_file
|
91 |
+
|
92 |
+
|
93 |
+
def process_input(mesh_file):
|
94 |
+
"""
|
95 |
+
Function to handle input changes and initialize visualization
|
96 |
+
|
97 |
+
Args:
|
98 |
+
mesh_file: Path to input mesh file
|
99 |
+
joint_checkpoint: Path to joint prediction checkpoint
|
100 |
+
conn_checkpoint: Path to connectivity prediction checkpoint
|
101 |
+
skin_checkpoint: Path to skinning prediction checkpoint
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
vis_file: Path to visualization file
|
105 |
+
"""
|
106 |
+
|
107 |
+
# For now just visualize the input mesh
|
108 |
+
|
109 |
+
normalized_mesh = normalize_mesh(trimesh.load(mesh_file))
|
110 |
+
normalized_mesh_file = "Anymate/tmp/normalized_mesh.obj"
|
111 |
+
normalized_mesh.export(normalized_mesh_file)
|
112 |
+
vis_file = visualize_results(mesh_file=normalized_mesh_file)
|
113 |
+
pc = process_mesh_to_pc(normalized_mesh_file)
|
114 |
+
pc = torch.from_numpy(pc).to(anymate_args.device).to(torch.float32)
|
115 |
+
|
116 |
+
print(pc.shape, pc.max(dim=0), pc.min(dim=0))
|
117 |
+
|
118 |
+
return normalized_mesh_file, vis_file, pc, None, None, None
|
119 |
+
|
120 |
+
|
121 |
+
def get_model(checkpoint):
|
122 |
+
model = load_checkpoint(checkpoint, anymate_args.device, anymate_args.num_joints)
|
123 |
+
return model, True
|
124 |
+
|
125 |
+
def get_result_joint(model, pc):
|
126 |
+
return get_joints(pc, model, anymate_args.device)
|
127 |
+
|
128 |
+
def get_result_connectivity(model, pc, joints):
|
129 |
+
return get_connectivity(pc, joints, model, anymate_args.device)
|
130 |
+
|
131 |
+
def get_result_skinning(model, pc):
|
132 |
+
with torch.no_grad():
|
133 |
+
skinning = model(pc)
|
134 |
+
return skinning
|
Anymate/utils/utils.py
ADDED
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import torch
|
2 |
+
from Anymate.model import EncoderDecoder
|
3 |
+
from sklearn.cluster import DBSCAN
|
4 |
+
|
5 |
+
def load_checkpoint(path, device, num_joints):
|
6 |
+
print(f"Loading model from {path}")
|
7 |
+
model_state = torch.load(path)
|
8 |
+
model_weights = model_state['state_dict']
|
9 |
+
|
10 |
+
try:
|
11 |
+
model_config = model_state['model_config']
|
12 |
+
model = EncoderDecoder(device=device, dtype=torch.float32, **model_config)
|
13 |
+
model.to(device)
|
14 |
+
model.load_state_dict(model_weights, strict=True)
|
15 |
+
except:
|
16 |
+
encoder = path.split('/')[-1].split('.')[0].split('-')[0]
|
17 |
+
decoder = path.split('/')[-1].split('.')[0].split('-')[1]
|
18 |
+
model = EncoderDecoder(encoder=encoder, decoder=decoder, device=device, dtype=torch.float32, num_joints=num_joints)
|
19 |
+
model.to(device)
|
20 |
+
model.load_state_dict(model_weights, strict=True)
|
21 |
+
|
22 |
+
print(f"Loaded model from {path}")
|
23 |
+
|
24 |
+
return model
|
25 |
+
|
26 |
+
def get_joint(pc, model, device='cuda', save=None, vox=None, eps=0.03, min_samples=1):
|
27 |
+
model.eval()
|
28 |
+
data = {'points_cloud': pc.unsqueeze(0)}
|
29 |
+
if vox is not None:
|
30 |
+
data['vox'] = vox.unsqueeze(0)
|
31 |
+
with torch.no_grad():
|
32 |
+
model.decoder.inference_mode(eps=eps, min_samples=min_samples)
|
33 |
+
joints = model(data, device=device)
|
34 |
+
joints = torch.tensor(joints, dtype=torch.float32).to(device)
|
35 |
+
|
36 |
+
if save is not None:
|
37 |
+
torch.save(joints, save)
|
38 |
+
|
39 |
+
return joints
|
40 |
+
|
41 |
+
def get_connectivity(pc, joints, model, device='cuda',return_prob=False, save=None):
|
42 |
+
model.eval()
|
43 |
+
data = {'points_cloud': pc.unsqueeze(0), 'joints': joints.unsqueeze(0), 'joints_num': torch.tensor([joints.shape[0]]),
|
44 |
+
'joints_mask': torch.ones(joints.shape[0], device=device).unsqueeze(0)}
|
45 |
+
with torch.no_grad():
|
46 |
+
conns = model(data, device=device).softmax(dim=-1)
|
47 |
+
conns = conns.squeeze(0) if return_prob else torch.argmax(conns, dim=-1).squeeze(0)
|
48 |
+
|
49 |
+
if save is not None:
|
50 |
+
torch.save(conns, save)
|
51 |
+
|
52 |
+
return conns
|
53 |
+
|
54 |
+
def get_skinning(pc, joints, conns, model, vertices=None, bones=None, device='cuda', save=None):
|
55 |
+
model.eval()
|
56 |
+
|
57 |
+
if bones is None:
|
58 |
+
bones = []
|
59 |
+
for i in range(joints.shape[0]):
|
60 |
+
if conns[i] != i:
|
61 |
+
bones.append(torch.cat((joints[conns[i]], joints[i]), dim=-1))
|
62 |
+
bones = torch.stack(bones, dim=0)
|
63 |
+
|
64 |
+
data = {'points_cloud': pc.unsqueeze(0), 'bones': bones.unsqueeze(0), 'bones_num': torch.tensor([bones.shape[0]]),
|
65 |
+
'bones_mask': torch.ones(bones.shape[0], device=device).unsqueeze(0)}
|
66 |
+
|
67 |
+
if vertices is not None:
|
68 |
+
data['vertices'] = vertices.unsqueeze(0)
|
69 |
+
model.decoder.inference = True
|
70 |
+
|
71 |
+
with torch.no_grad():
|
72 |
+
skins = model(data, device=device).softmax(dim=-1).squeeze(0)
|
73 |
+
|
74 |
+
if save is not None:
|
75 |
+
torch.save(skins, save)
|
76 |
+
|
77 |
+
return skins
|
Anymate/utils/vol_utils.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import torch
|
3 |
+
from ThirdParty.michelangelo.graphics.primitives import generate_dense_grid_points
|
4 |
+
from sklearn.cluster import DBSCAN
|
5 |
+
|
6 |
+
def get_vol(bounds=(-0.5, 0.0, -0.5, 0.5, 1.0, 0.5), octree_depth=6):
|
7 |
+
|
8 |
+
bbox_min = np.array(bounds[0:3])
|
9 |
+
bbox_max = np.array(bounds[3:6])
|
10 |
+
bbox_size = bbox_max - bbox_min
|
11 |
+
|
12 |
+
xyz_samples, grid_size, length = generate_dense_grid_points(
|
13 |
+
bbox_min=bbox_min,
|
14 |
+
bbox_max=bbox_max,
|
15 |
+
octree_depth=octree_depth,
|
16 |
+
indexing="ij"
|
17 |
+
)
|
18 |
+
xyz_samples = torch.FloatTensor(xyz_samples) # ((2^d)+1)^3
|
19 |
+
|
20 |
+
return xyz_samples
|
21 |
+
|
22 |
+
def get_co(vox, bounds=(-1.0, -1.0, -1.0, 1.0, 1.0, 1.0), dtype = torch.float32):
|
23 |
+
|
24 |
+
bbox_min = torch.tensor(bounds[0:3]).to(vox.device)
|
25 |
+
bbox_max = torch.tensor(bounds[3:6]).to(vox.device)
|
26 |
+
bbox_size = bbox_max - bbox_min
|
27 |
+
|
28 |
+
# ind = torch.argwhere(vox)
|
29 |
+
# ind = ind.to(dtype) / (vox.shape[0]) * bbox_size + bbox_min
|
30 |
+
ind = vox
|
31 |
+
ind = ind.to(dtype) / 64 * bbox_size + bbox_min
|
32 |
+
|
33 |
+
return ind.to(dtype)
|
34 |
+
|
35 |
+
def get_gt(vol, joints, octree_depth=6):
|
36 |
+
sigma = 2 / 2**octree_depth
|
37 |
+
|
38 |
+
dist = torch.cdist(vol, joints)
|
39 |
+
# print(dist.min(), dist.max())
|
40 |
+
|
41 |
+
dist = dist.min(dim=1).values
|
42 |
+
|
43 |
+
gt = torch.exp(-dist**2 / 2 / sigma**2)
|
44 |
+
|
45 |
+
return gt
|
46 |
+
|
47 |
+
def project_onto_planes(planes, coordinates):
|
48 |
+
"""
|
49 |
+
Does a projection of a 3D point onto a batch of 2D planes,
|
50 |
+
returning 2D plane coordinates.
|
51 |
+
|
52 |
+
Takes plane axes of shape n_planes, 3, 3
|
53 |
+
# Takes coordinates of shape N, M, 3
|
54 |
+
# returns projections of shape N*n_planes, M, 2
|
55 |
+
"""
|
56 |
+
N, M, C = coordinates.shape
|
57 |
+
n_planes, _, _ = planes.shape
|
58 |
+
coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
|
59 |
+
inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
|
60 |
+
projections = torch.bmm(coordinates, inv_planes)
|
61 |
+
return projections[..., :2]
|
62 |
+
|
63 |
+
def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
|
64 |
+
assert padding_mode == 'zeros'
|
65 |
+
N, n_planes, C, H, W = plane_features.shape
|
66 |
+
_, M, _ = coordinates.shape
|
67 |
+
plane_features = plane_features.view(N*n_planes, C, H, W)
|
68 |
+
|
69 |
+
# coordinates = (2/box_warp) * coordinates # TODO: add specific box bounds
|
70 |
+
|
71 |
+
projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
|
72 |
+
output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
|
73 |
+
return output_features
|
74 |
+
|
75 |
+
def generate_planes():
|
76 |
+
"""
|
77 |
+
Defines planes by the three vectors that form the "axes" of the
|
78 |
+
plane. Should work with arbitrary number of planes and planes of
|
79 |
+
arbitrary orientation.
|
80 |
+
"""
|
81 |
+
return torch.tensor([[[1, 0, 0],
|
82 |
+
[0, 1, 0],
|
83 |
+
[0, 0, 1]],
|
84 |
+
[[1, 0, 0],
|
85 |
+
[0, 0, 1],
|
86 |
+
[0, 1, 0]],
|
87 |
+
[[0, 0, 1],
|
88 |
+
[1, 0, 0],
|
89 |
+
[0, 1, 0]]], dtype=torch.float32)
|
90 |
+
|
91 |
+
def extract_keypoints(y_pred, vox):
|
92 |
+
|
93 |
+
y_pred = y_pred.detach().cpu().numpy()
|
94 |
+
vox = vox.detach().cpu().numpy()
|
95 |
+
volume = np.zeros([64, 64, 64])
|
96 |
+
volume[...] = -100
|
97 |
+
volume[vox[:, 0], vox[:, 1], vox[:, 2]] = y_pred.squeeze(-1)
|
98 |
+
|
99 |
+
clusters = []
|
100 |
+
cluster_model = DBSCAN(eps=1.8, min_samples=1)
|
101 |
+
|
102 |
+
level = min((0.85 * y_pred.max() + 0.15 * y_pred.min()).item(), 0)
|
103 |
+
potential_points = np.argwhere(volume >= level)
|
104 |
+
clustering = cluster_model.fit(potential_points)
|
105 |
+
for cluster in set(clustering.labels_):
|
106 |
+
if cluster == -1:
|
107 |
+
print('got noise', len(potential_points[clustering.labels_ == cluster]))
|
108 |
+
continue
|
109 |
+
clusters.append(potential_points[clustering.labels_ == cluster])
|
110 |
+
|
111 |
+
while True:
|
112 |
+
if np.all(np.array([(len(cluster) < 10) for cluster in clusters])):
|
113 |
+
break
|
114 |
+
new_clusters = []
|
115 |
+
for points in clusters:
|
116 |
+
if len(points) < 10:
|
117 |
+
new_clusters.append(points)
|
118 |
+
continue
|
119 |
+
|
120 |
+
value = volume[points[:, 0], points[:, 1], points[:, 2]]
|
121 |
+
|
122 |
+
potential_points = points[value >= (0.1 * value.max() + 0.9 * value.min())]
|
123 |
+
clustering = cluster_model.fit(potential_points)
|
124 |
+
for cluster in set(clustering.labels_):
|
125 |
+
if cluster == -1:
|
126 |
+
print('got noise', len(potential_points[clustering.labels_ == cluster]))
|
127 |
+
continue
|
128 |
+
new_clusters.append(potential_points[clustering.labels_ == cluster])
|
129 |
+
|
130 |
+
clusters = new_clusters
|
131 |
+
|
132 |
+
key_points = np.array([cluster.mean(axis=0) for cluster in clusters])
|
133 |
+
key_points = key_points / 32 - 1
|
134 |
+
|
135 |
+
return key_points
|
Render.py
ADDED
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
import bpy
|
3 |
+
import mathutils
|
4 |
+
from Anymate.utils.render_utils import empty, setup_armature
|
5 |
+
|
6 |
+
|
7 |
+
def parse_args():
|
8 |
+
parser = argparse.ArgumentParser(description='Anymate rendering script')
|
9 |
+
parser.add_argument('--path', type=str, required=True, help='Path to the model file')
|
10 |
+
return parser.parse_args()
|
11 |
+
|
12 |
+
args = parse_args()
|
13 |
+
|
14 |
+
print(f"Starting converting {args.path} to blender format...")
|
15 |
+
|
16 |
+
empty()
|
17 |
+
setup_armature(args.path)
|
ThirdParty/PointLLM/.gitignore
ADDED
@@ -0,0 +1,12 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
__pycache__
|
2 |
+
*.egg-info
|
3 |
+
.vscode
|
4 |
+
checkpoints
|
5 |
+
outputs
|
6 |
+
wandb
|
7 |
+
anno_data
|
8 |
+
objaverse_data
|
9 |
+
modelnet40_data
|
10 |
+
*.zip
|
11 |
+
*.ipynb
|
12 |
+
serving_workdirs
|
ThirdParty/PointLLM/README.md
ADDED
@@ -0,0 +1,353 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
<br>
|
2 |
+
<p align="center">
|
3 |
+
<h1 align="center"><img src="assets/icon.png" align="center" width="6.5%"><strong>PointLLM: Empowering Large Language Models to Understand Point Clouds</strong></h1>
|
4 |
+
<p align="center">
|
5 |
+
<a href='https://runsenxu.com/' target='_blank'>Runsen Xu</a> 
|
6 |
+
<a href='https://guanfang12.github.io/' target='_blank'>Xiaolong Wang</a> 
|
7 |
+
<a href='https://tai-wang.github.io/' target='_blank'>Tai Wang</a> 
|
8 |
+
<a href='http://yilunchen.com/about' target='_blank'>Yilun Chen</a> 
|
9 |
+
<a href='https://oceanpang.github.io/' target='_blank'>Jiangmiao Pang*</a> 
|
10 |
+
<a href='http://dahua.site/' target='_blank'>Dahua Lin</a> 
|
11 |
+
<br>
|
12 |
+
The Chinese University of Hong Kong Shanghai AI Laboratory Zhejiang University
|
13 |
+
</p>
|
14 |
+
</p>
|
15 |
+
|
16 |
+
<p align="center">
|
17 |
+
<a href="http://arxiv.org/abs/2308.16911" target='_**blank**'>
|
18 |
+
<img src="https://img.shields.io/badge/arXiv-2308.16911-blue?">
|
19 |
+
</a>
|
20 |
+
<a href="https://arxiv.org/pdf/2308.16911.pdf" target='_blank'>
|
21 |
+
<img src="https://img.shields.io/badge/Paper-📖-blue?">
|
22 |
+
</a>
|
23 |
+
<a href="https://runsenxu.com/projects/PointLLM" target='_blank'>
|
24 |
+
<img src="https://img.shields.io/badge/Project-🚀-blue">
|
25 |
+
</a>
|
26 |
+
<a href="http://101.230.144.196" target='_blank'>
|
27 |
+
<img src="https://img.shields.io/badge/Demo-🤗-blue">
|
28 |
+
</a>
|
29 |
+
<a href="" target='_blank'>
|
30 |
+
<img src="https://visitor-badge.laobi.icu/badge?page_id=OpenRobotLab.pointllm&left_color=gray&right_color=blue">
|
31 |
+
</a>
|
32 |
+
<a href="https://openxlab.org.cn/apps/detail/openxlab-app/PointLLM" target='_blank'>
|
33 |
+
<img src="https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg">
|
34 |
+
</a>
|
35 |
+
</p>
|
36 |
+
|
37 |
+
## 🏠 About
|
38 |
+
<!--  -->
|
39 |
+
<div style="text-align: center;">
|
40 |
+
<img src="assets/teaser.jpg" alt="Dialogue_Teaser" width=100% >
|
41 |
+
</div>
|
42 |
+
We introduce <b>PointLLM, a multi-modal large language model capable of understanding colored point clouds of objects.</b> It perceives object types, geometric structures, and appearance without concerns for ambiguous depth, occlusion, or viewpoint dependency. <b>We collect a novel dataset comprising 660K simple and 70K complex point-text instruction pairs</b> to enable a two-stage training strategy. To rigorously evaluate our model's perceptual abilities and its generalization capabilities, <b>we establish two benchmarks: Generative 3D Object Classification and 3D Object Captioning, assessed through three different evaluation methods.</b>
|
43 |
+
|
44 |
+
## 🔥 News
|
45 |
+
- [2024-09-06] We have uploaded the camera-ready version of PointLLM for ECCV 2024, which includes clearer writing and additional experimental results. Please check the paper [here](https://arxiv.org/abs/2308.16911).
|
46 |
+
- [2024-07-01] PointLLM has been accepted by ECCV 2024 with all "strong-accept" recommendation. 🎉 We are looking for self-motivated students to conduct research regarding PointLLM. Please send an email to [email protected] with your CV if you are interested!
|
47 |
+
- [2023-12-29] We release the codes of our online Gradio demo.
|
48 |
+
- [2023-12-26] We release the codes for model evaluation, including ChatGPT/GPT-4 evaluation and traditional metric evaluation.
|
49 |
+
- [2023-12-08] We release the codes for training and PointLLM-v1.2. The online demo has also been upgraded to the v1.2 version. Please enjoy! 🎉
|
50 |
+
- [2023-12-01] We have released an updated version of our paper (v2), which includes additional baseline comparisons, enhanced human-evaluation metrics, improved model performance (PointLLM-v1.2), and other refinements. Please check the updated version [here](https://arxiv.org/abs/2308.16911).
|
51 |
+
- [2023-10-18] We release our instruction-following data, including both the simple-description and complex instructions. Download [here](https://huggingface.co/datasets/RunsenXu/PointLLM).
|
52 |
+
- [2023-09-26] We release the inferencing codes with checkpoints as well as the Objaverse colored point cloud files we use. You can chat with PointLLM with your own machines.
|
53 |
+
- [2023-08-31] We release the [paper](http://arxiv.org/abs/2308.16911) of PointLLM and an online gradio [demo](http://101.230.144.196). Try it! 🎉
|
54 |
+
|
55 |
+
<!-- contents with emoji -->
|
56 |
+
## 📋 Contents
|
57 |
+
- [🤖 Online Demo](#-online-demo)
|
58 |
+
- [💬 Dialogue Examples](#-dialogue-examples)
|
59 |
+
- [🔍 Overview](#-overview)
|
60 |
+
- [📦 Training and Evaluation](#-training-and-evaluation)
|
61 |
+
- [📝 TODO List](#-todo-list)
|
62 |
+
- [🔗 Citation](#-citation)
|
63 |
+
- [📄 License](#-license)
|
64 |
+
- [📚 Related Work](#-related-work)
|
65 |
+
- [👏 Acknowledgements](#-acknowledgements)
|
66 |
+
|
67 |
+
## 🤖 Online Demo
|
68 |
+
<b>PointLLM is online! Try it at [http://101.230.144.196](http://101.230.144.196) or at [OpenXLab/PointLLM](https://openxlab.org.cn/apps/detail/openxlab-app/PointLLM).</b>
|
69 |
+
|
70 |
+
You can chat with PointLLM about the models of the [Objaverse](https://objaverse.allenai.org) dataset or about your own point clouds!
|
71 |
+
|
72 |
+
Please do not hesitate to tell us if you have any feedback! 😃
|
73 |
+
|
74 |
+
## 💬 Dialogue Examples
|
75 |
+
| Dialogue 1 | Dialogue 2| Dialogue 3 | Dialogue 4
|
76 |
+
| :-: | :-: | :-: | :-: |
|
77 |
+
| <img width="100%" src="assets/dialogue_1.jpg"> | <img width="100%" src="assets/dialogue_2.jpg"> | <img width="100%" src="assets/dialogue_3.jpg"> | <img width="100%" src="assets/dialogue_4.jpg"> |
|
78 |
+
|
79 |
+
|
80 |
+
## 🔍 Overview
|
81 |
+
|
82 |
+
### Model
|
83 |
+
<p align="center">
|
84 |
+
<img src="assets/model.jpg" align="center" width="100%">
|
85 |
+
</p>
|
86 |
+
The point encoder extracts features from the input point cloud and projects them to the latent space of the LLM backbone. The LLM backbone processes sequences of point tokens and text tokens, and generates the predicted tokens as the output.
|
87 |
+
|
88 |
+
### Experiment Results
|
89 |
+
#### Quantitative Comparisons with baselines.
|
90 |
+
Please refer to our paper for more results.
|
91 |
+
<p align="center">
|
92 |
+
<img src="assets/cls_results.png" align="center" width="100%">
|
93 |
+
</p>
|
94 |
+
<p align="center">
|
95 |
+
<img src="assets/caption_results.png" align="center" width="100%">
|
96 |
+
</p>
|
97 |
+
<b>!!!Note: Traditional metrics such as BLEU-1, ROUGE-L, and METEOR tend to favor shorter responses and may not effectively capture semantic accuracy. For a detailed discussion on this, please refer to our paper. We suggest the community not solely rely on these metrics for evaluation.</b>
|
98 |
+
|
99 |
+
#### Qualitative Comparisons with baselines.
|
100 |
+
Please refer to our paper for more results.
|
101 |
+
<p align="center">
|
102 |
+
<img src="assets/qualitative_comparisons_v2.png" align="center" width="100%">
|
103 |
+
</p>
|
104 |
+
|
105 |
+
## 📦 Training and Evaluation
|
106 |
+
### Installation
|
107 |
+
We test our codes under the following environment:
|
108 |
+
- Ubuntu 20.04
|
109 |
+
- NVIDIA Driver: 515.65.01
|
110 |
+
- CUDA 11.7
|
111 |
+
- Python 3.10.13
|
112 |
+
- PyTorch 2.0.1
|
113 |
+
- Transformers 4.28.0.dev(transformers.git@cae78c46)
|
114 |
+
|
115 |
+
To start:
|
116 |
+
1. Clone this repository.
|
117 |
+
```bash
|
118 |
+
git clone [email protected]:OpenRobotLab/PointLLM.git
|
119 |
+
cd PointLLM
|
120 |
+
```
|
121 |
+
2. Install packages
|
122 |
+
```bash
|
123 |
+
conda create -n pointllm python=3.10 -y
|
124 |
+
conda activate pointllm
|
125 |
+
pip install --upgrade pip # enable PEP 660 support
|
126 |
+
pip install -e .
|
127 |
+
|
128 |
+
# * for training
|
129 |
+
pip install ninja
|
130 |
+
pip install flash-attn
|
131 |
+
```
|
132 |
+
|
133 |
+
### Data Preparation
|
134 |
+
#### Objaverse Training Data
|
135 |
+
1. Download the two compressed files of 660K Objaverse colored point clouds [here](https://huggingface.co/datasets/RunsenXu/PointLLM/tree/main). They require about 77GB of storage space.
|
136 |
+
2. Run the following command to merge the two files into one and uncompress it. This will produce a folder named `8192_npy` containing 660K point cloud files named `{Objaverse_ID}_8192.npy`. Each file is a numpy array with dimensions (8192, 6), where the first three dimensions are `xyz` and the last three dimensions are `rgb` in [0, 1] range.
|
137 |
+
```bash
|
138 |
+
cat Objaverse_660K_8192_npy_split_a* > Objaverse_660K_8192_npy.tar.gz
|
139 |
+
tar -xvf Objaverse_660K_8192_npy.tar.gz
|
140 |
+
```
|
141 |
+
3. In `PointLLM` folder, create a folder `data` and create a soft link to the uncompressed file in the directory.
|
142 |
+
```bash
|
143 |
+
cd PointLLM
|
144 |
+
mkdir data
|
145 |
+
ln -s /path/to/8192_npy data/objaverse_data
|
146 |
+
```
|
147 |
+
|
148 |
+
#### Instruction-Following Data
|
149 |
+
1. In `PointLLM/data` folder, create a directory named `anno_data`.
|
150 |
+
2. Our instruction-following data, including both the simple-description and complex instructions, can be downloaded [here](https://huggingface.co/datasets/RunsenXu/PointLLM). If you have difficulty downloading the data (e.g. network issue), please email the authors.
|
151 |
+
- The simple-description data has 660K samples and the complex instructions have 70K samples.
|
152 |
+
- Both training data are based on the Objaverse dataset.
|
153 |
+
- The complex instructions are generated with GPT-4.
|
154 |
+
3. Put the data files in the `anno_data` directory. The directory should look like this:
|
155 |
+
```bash
|
156 |
+
PointLLM/data/anno_data
|
157 |
+
├── PointLLM_brief_description_660K_filtered.json
|
158 |
+
├── PointLLM_brief_description_660K.json
|
159 |
+
└── PointLLM_complex_instruction_70K.json
|
160 |
+
```
|
161 |
+
4. Note, the `PointLLM_brief_description_660K_filtered.json` is filtered from `PointLLM_brief_description_660K.json` by removing the 3000 objects we reserved as the validation set. If you want to reproduce the results in our paper, you should use the `PointLLM_brief_description_660K_filtered.json` for training. The `PointLLM_complex_instruction_70K.json` contains objects from the training set.
|
162 |
+
5. If you want to generate the complex instructions by yourself, please refer to our paper for other details. The system prompt is at `pointllm/data/data_generation/system_prompt_gpt4_0613.txt`.
|
163 |
+
|
164 |
+
#### Evaluation Data
|
165 |
+
1. Download the referencing GT `PointLLM_brief_description_val_200_GT.json` we use for the benchmarks on Objaverse dataset [here](https://huggingface.co/datasets/RunsenXu/PointLLM/blob/main/PointLLM_brief_description_val_200_GT.json), and put it in `PointLLM/data/anno_data`. We also provide the 3000 object ids we filter during training [here](https://huggingface.co/datasets/RunsenXu/PointLLM/blob/main/val_object_ids_3000.txt) and their corresponding referencing GT [here](https://huggingface.co/datasets/RunsenXu/PointLLM/blob/main/PointLLM_brief_description_val_3000_GT.json), which can be used to evaluate on all the 3000 objects.
|
166 |
+
2. Create a directory named `modelnet40_data` in `PointLLM/data`. Download the test split of ModelNet40 point clouds `modelnet40_test_8192pts_fps.dat` [here](https://huggingface.co/datasets/RunsenXu/PointLLM/blob/main/modelnet40_test_8192pts_fps.dat) and put it in `PointLLM/data/modelnet40_data`.
|
167 |
+
|
168 |
+
### Training
|
169 |
+
#### Download the Initial LLM and Point Encoder Weights
|
170 |
+
1. In `PointLLM` folder, create a directory named `checkpoints`.
|
171 |
+
2. Download the pre-trained LLM and point encoder: [
|
172 |
+
PointLLM_7B_v1.1_init](https://huggingface.co/RunsenXu/PointLLM_7B_v1.1_init/tree/main) or [PointLLM_13B_v1.1_init](https://huggingface.co/RunsenXu/PointLLM_13B_v1.1_init/tree/main). Put them in the `checkpoints` directory.
|
173 |
+
3. Note that the above "v1.1" means we use the Vicuna-v1.1 checkpoints, and you do **not** need to download the original LLaMA weights again.
|
174 |
+
|
175 |
+
#### Start Training
|
176 |
+
1. For stage-1 training, simply run:
|
177 |
+
```bash
|
178 |
+
cd PointLLM
|
179 |
+
scripts/PointLLM_train_stage1.sh
|
180 |
+
```
|
181 |
+
2. After stage-1 training, start stage-2 training:
|
182 |
+
```bash
|
183 |
+
scripts/PointLLM_train_stage2.sh
|
184 |
+
```
|
185 |
+
|
186 |
+
#### PointLLM-v1.1 and PointLLM-v1.2
|
187 |
+
Usually, you do not have to care about the following contents. They are only for reproducing the results in our v1 paper (PointLLM-v1.1). If you want to compare with our models or use our models for downstream tasks, please use PointLLM-v1.2 (refer to our v2 paper), which has better performance.
|
188 |
+
<details>
|
189 |
+
<summary>The following steps are for reproducing PointLLM-v1.1 (click to expand)</summary>
|
190 |
+
|
191 |
+
1. PointLLM v1.1 and v1.2 use slightly different pre-trained point encoders and projectors. If you want to reproduce PointLLM v1.1, edit the `config.json` file in the directory of initial LLM and point encoder weights, for example, `vim checkpoints/PointLLM_7B_v1.1_init/config.json`.
|
192 |
+
|
193 |
+
2. Change the key `"point_backbone_config_name"` to specify another point encoder config:
|
194 |
+
```bash
|
195 |
+
# change from
|
196 |
+
"point_backbone_config_name": "PointTransformer_8192point_2layer" # v1.2
|
197 |
+
# to
|
198 |
+
"point_backbone_config_name": "PointTransformer_base_8192point", # v1.1
|
199 |
+
```
|
200 |
+
|
201 |
+
3. Edit the checkpoint path of the point encoder in `scripts/train_stage1.sh`:
|
202 |
+
```bash
|
203 |
+
# change from
|
204 |
+
point_backbone_ckpt=$model_name_or_path/point_bert_v1.2.pt # v1.2
|
205 |
+
# to
|
206 |
+
point_backbone_ckpt=$model_name_or_path/point_bert_v1.1.pt # v1.1
|
207 |
+
```
|
208 |
+
</details>
|
209 |
+
|
210 |
+
### Chatting
|
211 |
+
1. The trained model checkpoints are available [here](https://huggingface.co/RunsenXu) (including different versions of PointLLM).
|
212 |
+
2. Run the following command to launch a chatbot using the `torch.float32` data type for chatting about 3D models of Objaverse. The model checkpoints will be downloaded automatically. You can also manually download the model checkpoints and specify their paths. Here is an example:
|
213 |
+
```bash
|
214 |
+
cd PointLLM
|
215 |
+
PYTHONPATH=$PWD python pointllm/eval/PointLLM_chat.py --model_name RunsenXu/PointLLM_7B_v1.2 --data_name data/objaverse_data --torch_dtype float32
|
216 |
+
```
|
217 |
+
3. You can also easily modify the codes for using point clouds other than those from Objaverse, as long as the point clouds input to the model have dimensions (N, 6), where the first three dimensions are `xyz` and the last three dimensions are `rgb` (in [0, 1] range). You may sample the point clouds to have 8192 points, as our model is trained on such point clouds.
|
218 |
+
4. The following table shows GPU requirements for different models and data types. We recommend using `torch.bfloat16` if applicable, which is used in the experiments in our paper.
|
219 |
+
|
220 |
+
| Model | Data Type | GPU Memory |
|
221 |
+
|:--------:|:---------:|:----------:|
|
222 |
+
| PointLLM-7B | torch.float16 | 14GB |
|
223 |
+
| PointLLM-7B | torch.float32 | 28GB |
|
224 |
+
| PointLLM-13B | torch.float16 | 26GB |
|
225 |
+
| PointLLM-13B | torch.float32 | 52GB |
|
226 |
+
|
227 |
+
### Gradio Demo
|
228 |
+
1. We provide the codes for our online Gradio demo. You can run the following commands to launch the demo locally for chatting and visualization.
|
229 |
+
```bash
|
230 |
+
cd PointLLM
|
231 |
+
PYTHONPATH=$PWD python pointllm/eval/chat_gradio.py --model_name RunsenXu/PointLLM_7B_v1.2 --data_name data/objaverse_data
|
232 |
+
```
|
233 |
+
2. Kind remind: if you want to release the demo in public, please refer to https://www.gradio.app/guides/sharing-your-app#security-and-file-access.
|
234 |
+
|
235 |
+
### Evaluation
|
236 |
+
#### Inferencing
|
237 |
+
1. Run the following commands to infer the results.
|
238 |
+
2. Different commands for inferencing on different benchmarks (PointLLM_7B_v1.2 as an example):
|
239 |
+
```bash
|
240 |
+
cd PointLLM
|
241 |
+
export PYTHONPATH=$PWD
|
242 |
+
|
243 |
+
# Open Vocabulary Classification on Objaverse
|
244 |
+
python pointllm/eval/eval_objaverse.py --model_name RunsenXu/PointLLM_7B_v1.2 --task_type classification --prompt_index 0 # or --prompt_index 1
|
245 |
+
|
246 |
+
# Object captioning on Objaverse
|
247 |
+
python pointllm/eval/eval_objaverse.py --model_name RunsenXu/PointLLM_7B_v1.2 --task_type captioning --prompt_index 2
|
248 |
+
|
249 |
+
# Close-set Zero-shot Classification on ModelNet40
|
250 |
+
python pointllm/eval/eval_modelnet_cls.py --model_name RunsenXu/PointLLM_7B_v1.2 --prompt_index 0 # or --prompt_index 1
|
251 |
+
```
|
252 |
+
3. Please check the default command-line arguments of these two scripts. You can specify different prompts, data paths, and other parameters.
|
253 |
+
4. After inferencing, the results will be saved in `{model_name}/evaluation` as a dict with the following format:
|
254 |
+
```bash
|
255 |
+
{
|
256 |
+
"prompt": "",
|
257 |
+
"results": [
|
258 |
+
{
|
259 |
+
"object_id": "",
|
260 |
+
"ground_truth": "",
|
261 |
+
"model_output": "",
|
262 |
+
"label_name": "" # only for classification on modelnet40
|
263 |
+
}
|
264 |
+
]
|
265 |
+
}
|
266 |
+
```
|
267 |
+
|
268 |
+
#### ChatGPT/GPT-4 Evaluation
|
269 |
+
1. Get your OpenAI API key at [https://platform.openai.com/api-keys](https://platform.openai.com/api-keys).
|
270 |
+
2. Run the following commands to evaluate the model outputs in parallel with ChatGPT/GPT-4 (which cost approximately $1.5 to $2.2 USD).
|
271 |
+
```bash
|
272 |
+
cd PointLLM
|
273 |
+
export PYTHONPATH=$PWD
|
274 |
+
export OPENAI_API_KEY=sk-****
|
275 |
+
|
276 |
+
# Open Vocabulary Classification on Objaverse
|
277 |
+
python pointllm/eval/evaluator.py --results_path /path/to/model_output --model_type gpt-4-0613 --eval_type open-free-form-classification --parallel --num_workers 15
|
278 |
+
|
279 |
+
# Object captioning on Objaverse
|
280 |
+
python pointllm/eval/evaluator.py --results_path /path/to/model_output --model_type gpt-4-0613 --eval_type object-captioning --parallel --num_workers 15
|
281 |
+
|
282 |
+
# Close-set Zero-shot Classification on ModelNet40
|
283 |
+
python pointllm/eval/evaluator.py --results_path /path/to/model_output --model_type gpt-3.5-turbo-0613 --eval_type modelnet-close-set-classification --parallel --num_workers 15
|
284 |
+
```
|
285 |
+
3. The evaluation script supports interruption and resumption. You can interrupt the evaluation process at any time by using `Ctrl+C`. This will save the temporary results. If an error occurs during the evaluation, the script will also save the current state. You can resume the evaluation from where it left off by running the same command again.
|
286 |
+
4. The evaluation results will be saved in `{model_name}/evaluation` as another dict.
|
287 |
+
Some of the metrics are explained as follows:
|
288 |
+
```bash
|
289 |
+
"average_score": The GPT-evaluated captioning score we report in our paper.
|
290 |
+
"accuracy": The classification accuracy we report in our paper, including random choices made by ChatGPT when model outputs are vague or ambiguous and ChatGPT outputs "INVALID".
|
291 |
+
"clean_accuracy": The classification accuracy after removing those "INVALID" outputs.
|
292 |
+
"total_predictions": The number of predictions.
|
293 |
+
"correct_predictions": The number of correct predictions.
|
294 |
+
"invalid_responses": The number of "INVALID" outputs by ChatGPT.
|
295 |
+
|
296 |
+
# Some other statistics for calling OpenAI API
|
297 |
+
"prompt_tokens": The total number of tokens of the prompts for ChatGPT/GPT-4.
|
298 |
+
"completion_tokens": The total number of tokens of the completion results from ChatGPT/GPT-4.
|
299 |
+
"GPT_cost": The API cost of the whole evaluation process, in US Dollars 💵.
|
300 |
+
```
|
301 |
+
5. <b>Open-Step Evaluation.</b> You can also start evaluation immediately after inferencing by passing the `--start_eval` flag and specifying the `--gpt_type`. For example:
|
302 |
+
```bash
|
303 |
+
python pointllm/eval/eval_objaverse.py --model_name RunsenXu/PointLLM_7B_v1.2 --task_type classification --prompt_index 0 --start_eval --gpt_type gpt-4-0613
|
304 |
+
```
|
305 |
+
|
306 |
+
#### Traditional Metric Evaluation
|
307 |
+
1. For the object captioning task, run the following command to evaluate model outputs with traditional metrics including BLEU, ROUGE, METEOR, Sentence-BERT, and SimCSE.
|
308 |
+
```bash
|
309 |
+
python pointllm/eval/traditional_evaluator.py --results_path /path/to/model_captioning_output
|
310 |
+
```
|
311 |
+
2. Note that we recommend not using BLEU, ROUGE, and METEOR for evaluation as they favor short captions and fall short of capturing semantic accuracy and diversity.
|
312 |
+
|
313 |
+
## 📝 TODO List
|
314 |
+
- [x] Add inferencing codes with checkpoints.
|
315 |
+
- [x] Release instruction-following data.
|
316 |
+
- [x] Add training codes.
|
317 |
+
- [x] Add evaluation codes.
|
318 |
+
- [x] Add gradio demo codes.
|
319 |
+
|
320 |
+
Community contributions are welcome!👇 If you need any support, please feel free to open an issue or contact us.
|
321 |
+
- [ ] Support Phi-2 LLM to make PointLLM more accessible to the community.
|
322 |
+
- [ ] Support Chinese LLMs like InternLM.
|
323 |
+
|
324 |
+
## 🔗 Citation
|
325 |
+
|
326 |
+
If you find our work and this codebase helpful, please consider starring this repo 🌟 and cite:
|
327 |
+
|
328 |
+
```bibtex
|
329 |
+
@article{xu2023pointllm,
|
330 |
+
title={PointLLM: Empowering Large Language Models to Understand Point Clouds},
|
331 |
+
author={Xu, Runsen and Wang, Xiaolong and Wang, Tai and Chen, Yilun and Pang, Jiangmiao and Lin, Dahua},
|
332 |
+
journal={arXiv preprint arXiv:2308.16911},
|
333 |
+
year={2023}
|
334 |
+
}
|
335 |
+
```
|
336 |
+
|
337 |
+
## 📄 License
|
338 |
+
<a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-nc-sa/4.0/80x15.png" /></a>
|
339 |
+
<br />
|
340 |
+
This work is under the <a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/">Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License</a>.
|
341 |
+
|
342 |
+
## 📚 Related Work
|
343 |
+
Together, Let's make LLM for 3D great!
|
344 |
+
- [Point-Bind & Point-LLM](https://arxiv.org/abs/2309.00615): aligns point clouds with Image-Bind, and leverages ImageBind-LLM to reason multi-modality input without 3D-instruction data training.
|
345 |
+
- [3D-LLM](https://arxiv.org/abs/2307.12981): employs 2D foundation models to encode multi-view images of 3D point clouds.
|
346 |
+
|
347 |
+
|
348 |
+
## 👏 Acknowledgements
|
349 |
+
- [LLaVA](https://github.com/haotian-liu/LLaVA): Our codebase is built upon LLaVA.
|
350 |
+
- [Vicuna](https://github.com/lm-sys/FastChat): We use the Vicuna-7B and Vicuna-13B checkpoints.
|
351 |
+
- [Objaverse](https://objaverse.allenai.org): We use models of the Objaverse dataset for training and evaluation.
|
352 |
+
- [Cap3D](https://github.com/crockwell/Cap3D/): We use the Cap3D captioning data for our data generation.
|
353 |
+
- [ULIP-2](https://github.com/salesforce/ULIP): We use ULIP-2 for pre-training our point cloud encoder.
|
ThirdParty/PointLLM/__init__.py
ADDED
File without changes
|
ThirdParty/PointLLM/pointllm/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
# from .model import PointLLMLlamaForCausalLM
|
ThirdParty/PointLLM/pointllm/conversation.py
ADDED
@@ -0,0 +1,375 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import dataclasses
|
2 |
+
from enum import auto, Enum
|
3 |
+
from typing import List, Tuple
|
4 |
+
|
5 |
+
|
6 |
+
class SeparatorStyle(Enum):
|
7 |
+
"""Different separator style."""
|
8 |
+
SINGLE = auto()
|
9 |
+
TWO = auto()
|
10 |
+
MPT = auto()
|
11 |
+
|
12 |
+
|
13 |
+
@dataclasses.dataclass
|
14 |
+
class Conversation:
|
15 |
+
"""A class that keeps all conversation history."""
|
16 |
+
system: str
|
17 |
+
roles: List[str]
|
18 |
+
messages: List[List[str]]
|
19 |
+
offset: int
|
20 |
+
sep_style: SeparatorStyle = SeparatorStyle.SINGLE
|
21 |
+
sep: str = "###"
|
22 |
+
sep2: str = None
|
23 |
+
version: str = "Unknown"
|
24 |
+
|
25 |
+
skip_next: bool = False
|
26 |
+
|
27 |
+
def reset(self):
|
28 |
+
self.messages = self.messages[:self.offset]
|
29 |
+
|
30 |
+
def get_prompt(self):
|
31 |
+
if self.sep_style == SeparatorStyle.SINGLE:
|
32 |
+
ret = self.system + self.sep
|
33 |
+
for role, message in self.messages:
|
34 |
+
if message:
|
35 |
+
if type(message) is tuple:
|
36 |
+
message, _, _ = message
|
37 |
+
ret += role + ": " + message + self.sep
|
38 |
+
else:
|
39 |
+
ret += role + ":"
|
40 |
+
return ret
|
41 |
+
elif self.sep_style == SeparatorStyle.TWO:
|
42 |
+
seps = [self.sep, self.sep2]
|
43 |
+
ret = self.system + seps[0]
|
44 |
+
for i, (role, message) in enumerate(self.messages):
|
45 |
+
if message:
|
46 |
+
if type(message) is tuple:
|
47 |
+
message, _, _ = message
|
48 |
+
ret += role + ": " + message + seps[i % 2]
|
49 |
+
else:
|
50 |
+
ret += role + ":"
|
51 |
+
return ret
|
52 |
+
if self.sep_style == SeparatorStyle.MPT:
|
53 |
+
ret = self.system + self.sep
|
54 |
+
for role, message in self.messages:
|
55 |
+
if message:
|
56 |
+
if type(message) is tuple:
|
57 |
+
message, _, _ = message
|
58 |
+
ret += role + message + self.sep
|
59 |
+
else:
|
60 |
+
ret += role
|
61 |
+
return ret
|
62 |
+
else:
|
63 |
+
raise ValueError(f"Invalid style: {self.sep_style}")
|
64 |
+
|
65 |
+
def append_message(self, role, message):
|
66 |
+
self.messages.append([role, message])
|
67 |
+
|
68 |
+
def pop_last_none_message(self):
|
69 |
+
# * pop the last message if it's None, this is used for multi-round dialogue
|
70 |
+
if self.messages[-1][1] is None:
|
71 |
+
self.messages.pop()
|
72 |
+
|
73 |
+
def get_images(self, return_pil=False):
|
74 |
+
images = []
|
75 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
76 |
+
if i % 2 == 0:
|
77 |
+
if type(msg) is tuple:
|
78 |
+
import base64
|
79 |
+
from io import BytesIO
|
80 |
+
from PIL import Image
|
81 |
+
msg, image, image_process_mode = msg
|
82 |
+
if image_process_mode == "Pad":
|
83 |
+
def expand2square(pil_img, background_color=(122, 116, 104)):
|
84 |
+
width, height = pil_img.size
|
85 |
+
if width == height:
|
86 |
+
return pil_img
|
87 |
+
elif width > height:
|
88 |
+
result = Image.new(pil_img.mode, (width, width), background_color)
|
89 |
+
result.paste(pil_img, (0, (width - height) // 2))
|
90 |
+
return result
|
91 |
+
else:
|
92 |
+
result = Image.new(pil_img.mode, (height, height), background_color)
|
93 |
+
result.paste(pil_img, ((height - width) // 2, 0))
|
94 |
+
return result
|
95 |
+
image = expand2square(image)
|
96 |
+
elif image_process_mode == "Crop":
|
97 |
+
pass
|
98 |
+
elif image_process_mode == "Resize":
|
99 |
+
image = image.resize((224, 224))
|
100 |
+
else:
|
101 |
+
raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
|
102 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
103 |
+
aspect_ratio = max_hw / min_hw
|
104 |
+
max_len, min_len = 800, 400
|
105 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
106 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
107 |
+
W, H = image.size
|
108 |
+
if H > W:
|
109 |
+
H, W = longest_edge, shortest_edge
|
110 |
+
else:
|
111 |
+
H, W = shortest_edge, longest_edge
|
112 |
+
image = image.resize((W, H))
|
113 |
+
if return_pil:
|
114 |
+
images.append(image)
|
115 |
+
else:
|
116 |
+
buffered = BytesIO()
|
117 |
+
image.save(buffered, format="JPEG")
|
118 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
119 |
+
images.append(img_b64_str)
|
120 |
+
return images
|
121 |
+
|
122 |
+
def to_gradio_chatbot(self):
|
123 |
+
ret = []
|
124 |
+
for i, (role, msg) in enumerate(self.messages[self.offset:]):
|
125 |
+
if i % 2 == 0:
|
126 |
+
if type(msg) is tuple:
|
127 |
+
import base64
|
128 |
+
from io import BytesIO
|
129 |
+
msg, image, image_process_mode = msg
|
130 |
+
max_hw, min_hw = max(image.size), min(image.size)
|
131 |
+
aspect_ratio = max_hw / min_hw
|
132 |
+
max_len, min_len = 800, 400
|
133 |
+
shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
|
134 |
+
longest_edge = int(shortest_edge * aspect_ratio)
|
135 |
+
W, H = image.size
|
136 |
+
if H > W:
|
137 |
+
H, W = longest_edge, shortest_edge
|
138 |
+
else:
|
139 |
+
H, W = shortest_edge, longest_edge
|
140 |
+
image = image.resize((W, H))
|
141 |
+
# image = image.resize((224, 224))
|
142 |
+
buffered = BytesIO()
|
143 |
+
image.save(buffered, format="JPEG")
|
144 |
+
img_b64_str = base64.b64encode(buffered.getvalue()).decode()
|
145 |
+
img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
|
146 |
+
msg = msg.replace('<image>', img_str)
|
147 |
+
ret.append([msg, None])
|
148 |
+
else:
|
149 |
+
ret[-1][-1] = msg
|
150 |
+
return ret
|
151 |
+
|
152 |
+
def copy(self):
|
153 |
+
return Conversation(
|
154 |
+
system=self.system,
|
155 |
+
roles=self.roles,
|
156 |
+
messages=[[x, y] for x, y in self.messages],
|
157 |
+
offset=self.offset,
|
158 |
+
sep_style=self.sep_style,
|
159 |
+
sep=self.sep,
|
160 |
+
sep2=self.sep2)
|
161 |
+
|
162 |
+
def dict(self):
|
163 |
+
if len(self.get_images()) > 0:
|
164 |
+
return {
|
165 |
+
"system": self.system,
|
166 |
+
"roles": self.roles,
|
167 |
+
"messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
|
168 |
+
"offset": self.offset,
|
169 |
+
"sep": self.sep,
|
170 |
+
"sep2": self.sep2,
|
171 |
+
}
|
172 |
+
return {
|
173 |
+
"system": self.system,
|
174 |
+
"roles": self.roles,
|
175 |
+
"messages": self.messages,
|
176 |
+
"offset": self.offset,
|
177 |
+
"sep": self.sep,
|
178 |
+
"sep2": self.sep2,
|
179 |
+
}
|
180 |
+
|
181 |
+
|
182 |
+
conv_v1 = Conversation(
|
183 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
184 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
185 |
+
roles=("Human", "Assistant"),
|
186 |
+
messages=(
|
187 |
+
("Human", "Give three tips for staying healthy."),
|
188 |
+
("Assistant",
|
189 |
+
"Sure, here are three tips for staying healthy:\n"
|
190 |
+
"1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
|
191 |
+
"It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
|
192 |
+
"and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
|
193 |
+
"75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
|
194 |
+
"activities at least two days per week.\n"
|
195 |
+
"2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
|
196 |
+
"vegetables, whole grains, lean proteins, and healthy fats can help support "
|
197 |
+
"your overall health. Try to limit your intake of processed and high-sugar foods, "
|
198 |
+
"and aim to drink plenty of water throughout the day.\n"
|
199 |
+
"3. Get enough sleep: Getting enough quality sleep is essential for your physical "
|
200 |
+
"and mental health. Adults should aim for seven to nine hours of sleep per night. "
|
201 |
+
"Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
|
202 |
+
"help improve the quality of your sleep.")
|
203 |
+
),
|
204 |
+
offset=2,
|
205 |
+
sep_style=SeparatorStyle.SINGLE,
|
206 |
+
sep="###",
|
207 |
+
)
|
208 |
+
|
209 |
+
conv_v1_2 = Conversation(
|
210 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
211 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
212 |
+
roles=("Human", "Assistant"),
|
213 |
+
messages=(
|
214 |
+
("Human", "What are the key differences between renewable and non-renewable energy sources?"),
|
215 |
+
("Assistant",
|
216 |
+
"Renewable energy sources are those that can be replenished naturally in a relatively "
|
217 |
+
"short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
|
218 |
+
"Non-renewable energy sources, on the other hand, are finite and will eventually be "
|
219 |
+
"depleted, such as coal, oil, and natural gas. Here are some key differences between "
|
220 |
+
"renewable and non-renewable energy sources:\n"
|
221 |
+
"1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
|
222 |
+
"energy sources are finite and will eventually run out.\n"
|
223 |
+
"2. Environmental impact: Renewable energy sources have a much lower environmental impact "
|
224 |
+
"than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
|
225 |
+
"and other negative effects.\n"
|
226 |
+
"3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
|
227 |
+
"have lower operational costs than non-renewable sources.\n"
|
228 |
+
"4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
|
229 |
+
"locations than non-renewable sources.\n"
|
230 |
+
"5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
|
231 |
+
"situations and needs, while non-renewable sources are more rigid and inflexible.\n"
|
232 |
+
"6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
|
233 |
+
"non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
|
234 |
+
),
|
235 |
+
offset=2,
|
236 |
+
sep_style=SeparatorStyle.SINGLE,
|
237 |
+
sep="###",
|
238 |
+
)
|
239 |
+
|
240 |
+
conv_vicuna_v1_1 = Conversation(
|
241 |
+
system="A chat between a curious user and an artificial intelligence assistant. "
|
242 |
+
"The assistant gives helpful, detailed, and polite answers to the user's questions.",
|
243 |
+
roles=("USER", "ASSISTANT"),
|
244 |
+
version="v1",
|
245 |
+
messages=(),
|
246 |
+
offset=0,
|
247 |
+
sep_style=SeparatorStyle.TWO,
|
248 |
+
sep=" ",
|
249 |
+
sep2="</s>",
|
250 |
+
)
|
251 |
+
|
252 |
+
conv_mpt = Conversation(
|
253 |
+
system="""<|im_start|>system
|
254 |
+
- You are a helpful language and vision assistant.
|
255 |
+
- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
|
256 |
+
- You should follow the instructions carefully and explain your answers in detail.""",
|
257 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
258 |
+
version="mpt",
|
259 |
+
messages=(),
|
260 |
+
offset=0,
|
261 |
+
sep_style=SeparatorStyle.MPT,
|
262 |
+
sep="<|im_end|>",
|
263 |
+
)
|
264 |
+
|
265 |
+
conv_mpt_text = Conversation(
|
266 |
+
system="""<|im_start|>system
|
267 |
+
- You are a helpful assistant chatbot trained by MosaicML.
|
268 |
+
- You answer questions.
|
269 |
+
- You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
|
270 |
+
- You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
|
271 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
272 |
+
version="mpt",
|
273 |
+
messages=(),
|
274 |
+
offset=0,
|
275 |
+
sep_style=SeparatorStyle.MPT,
|
276 |
+
sep="<|im_end|>",
|
277 |
+
)
|
278 |
+
|
279 |
+
conv_bair_v1 = Conversation(
|
280 |
+
system="BEGINNING OF CONVERSATION:",
|
281 |
+
roles=("USER", "GPT"),
|
282 |
+
messages=(),
|
283 |
+
offset=0,
|
284 |
+
sep_style=SeparatorStyle.TWO,
|
285 |
+
sep=" ",
|
286 |
+
sep2="</s>",
|
287 |
+
)
|
288 |
+
|
289 |
+
simple_conv = Conversation(
|
290 |
+
system="A chat between a curious human and an artificial intelligence assistant. "
|
291 |
+
"The assistant gives helpful, detailed, and polite answers to the human's questions.",
|
292 |
+
roles=("Human", "Assistant"),
|
293 |
+
messages=(
|
294 |
+
("Human", "Hi!"),
|
295 |
+
("Assistant", "Hi there! How can I help you today?")
|
296 |
+
),
|
297 |
+
offset=2,
|
298 |
+
sep_style=SeparatorStyle.SINGLE,
|
299 |
+
sep="###",
|
300 |
+
)
|
301 |
+
|
302 |
+
simple_conv_multimodal = Conversation(
|
303 |
+
system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
|
304 |
+
"You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
305 |
+
"Follow the instructions carefully and explain your answers in detail.",
|
306 |
+
roles=("Human", "Assistant"),
|
307 |
+
messages=(
|
308 |
+
("Human", "Hi!"),
|
309 |
+
("Assistant", "Hi there! How can I help you today?\n")
|
310 |
+
),
|
311 |
+
offset=2,
|
312 |
+
sep_style=SeparatorStyle.SINGLE,
|
313 |
+
sep="###",
|
314 |
+
)
|
315 |
+
|
316 |
+
simple_conv_mpt_multimodal = Conversation(
|
317 |
+
system="""<|im_start|>system
|
318 |
+
- You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
|
319 |
+
- You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
|
320 |
+
- You should follow the instructions carefully and explain your answers in detail.""",
|
321 |
+
roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
|
322 |
+
version="mpt",
|
323 |
+
messages=(),
|
324 |
+
offset=0,
|
325 |
+
sep_style=SeparatorStyle.MPT,
|
326 |
+
sep="<|im_end|>",
|
327 |
+
)
|
328 |
+
|
329 |
+
simple_conv_legacy = Conversation(
|
330 |
+
system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
|
331 |
+
"You are designed to assist human with a variety of tasks using natural language."
|
332 |
+
"Follow the instructions carefully.",
|
333 |
+
roles=("Human", "Assistant"),
|
334 |
+
messages=(
|
335 |
+
("Human", "Hi!\n\n### Response:"),
|
336 |
+
("Assistant", "Hi there! How can I help you today?\n")
|
337 |
+
),
|
338 |
+
offset=2,
|
339 |
+
sep_style=SeparatorStyle.SINGLE,
|
340 |
+
sep="###",
|
341 |
+
)
|
342 |
+
|
343 |
+
conv_llava_v1 = Conversation(
|
344 |
+
system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
|
345 |
+
"You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
|
346 |
+
"Follow the instructions carefully and explain your answers in detail.",
|
347 |
+
roles=("USER", "ASSISTANT"),
|
348 |
+
version="v1",
|
349 |
+
messages=(),
|
350 |
+
offset=0,
|
351 |
+
sep_style=SeparatorStyle.TWO,
|
352 |
+
sep=" ",
|
353 |
+
sep2="</s>",
|
354 |
+
)
|
355 |
+
|
356 |
+
default_conversation = conv_v1_2
|
357 |
+
conv_templates = {
|
358 |
+
"default": conv_v1_2,
|
359 |
+
"simple": simple_conv,
|
360 |
+
"simple_legacy": simple_conv_legacy,
|
361 |
+
"multimodal": simple_conv_multimodal,
|
362 |
+
"mpt_multimodal": simple_conv_mpt_multimodal,
|
363 |
+
"llava_v1": conv_llava_v1,
|
364 |
+
|
365 |
+
# fastchat
|
366 |
+
"v1": conv_v1_2,
|
367 |
+
"bair_v1": conv_bair_v1,
|
368 |
+
"vicuna_v1_1": conv_vicuna_v1_1,
|
369 |
+
"mpt": conv_mpt,
|
370 |
+
"mpt_text": conv_mpt_text,
|
371 |
+
}
|
372 |
+
|
373 |
+
|
374 |
+
if __name__ == "__main__":
|
375 |
+
print(default_conversation.get_prompt())
|
ThirdParty/PointLLM/pointllm/data/__init__.py
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
from .utils import load_objaverse_point_cloud, pc_norm, farthest_point_sample
|
2 |
+
from .object_point_dataset import ObjectPointCloudDataset, make_object_point_data_module
|
3 |
+
from .modelnet import ModelNet
|
ThirdParty/PointLLM/pointllm/data/modelnet.py
ADDED
@@ -0,0 +1,147 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import torch
|
3 |
+
import numpy as np
|
4 |
+
import pickle
|
5 |
+
from torch.utils.data import Dataset
|
6 |
+
from pointllm.utils import *
|
7 |
+
from pointllm.data.utils import *
|
8 |
+
|
9 |
+
class ModelNet(Dataset):
|
10 |
+
def __init__(self, config_path, split, subset_nums=-1, use_color=False):
|
11 |
+
"""
|
12 |
+
Args:
|
13 |
+
data_args:
|
14 |
+
split: train or test
|
15 |
+
"""
|
16 |
+
super(ModelNet, self).__init__()
|
17 |
+
|
18 |
+
if config_path is None:
|
19 |
+
# * use the default config file in the same dir
|
20 |
+
config_path = os.path.join(os.path.dirname(__file__), "modelnet_config", "ModelNet40.yaml")
|
21 |
+
|
22 |
+
config = cfg_from_yaml_file(config_path)
|
23 |
+
# * check data path
|
24 |
+
self.root = config["DATA_PATH"]
|
25 |
+
|
26 |
+
if not os.path.exists(self.root):
|
27 |
+
print(f"Data path {self.root} does not exist. Please check your data path.")
|
28 |
+
exit()
|
29 |
+
|
30 |
+
self.npoints = config.npoints
|
31 |
+
self.num_category = config.NUM_CATEGORY # * should be 40
|
32 |
+
self.random_sample = config.random_sampling
|
33 |
+
self.use_height = config.use_height
|
34 |
+
self.use_normals = config.USE_NORMALS
|
35 |
+
self.subset_nums = subset_nums
|
36 |
+
self.normalize_pc = True
|
37 |
+
self.use_color = use_color
|
38 |
+
|
39 |
+
if self.use_height or self.use_normals:
|
40 |
+
print(f"Warning: Usually we don't use height or normals for shapenet but use_height: {self.use_height} and \
|
41 |
+
use_normals: {self.use_normals}.")
|
42 |
+
|
43 |
+
self.split = split
|
44 |
+
assert (self.split == 'train' or self.split == 'test')
|
45 |
+
|
46 |
+
self.catfile = os.path.join(os.path.dirname(__file__), "modelnet_config", 'modelnet40_shape_names_modified.txt')
|
47 |
+
|
48 |
+
# "tv_stand" -> "tv stand"
|
49 |
+
self.categories = [line.rstrip() for line in open(self.catfile)] # * list of category names
|
50 |
+
|
51 |
+
self.save_path = os.path.join(self.root,
|
52 |
+
'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, self.split, self.npoints))
|
53 |
+
|
54 |
+
print('Load processed data from %s...' % self.save_path)
|
55 |
+
with open(self.save_path, 'rb') as f:
|
56 |
+
self.list_of_points, self.list_of_labels = pickle.load(f) # * ndarray of N, C: (8192, 6) (xyz and normals)
|
57 |
+
|
58 |
+
if self.subset_nums > 0:
|
59 |
+
# * set random seed
|
60 |
+
import random
|
61 |
+
random.seed(0)
|
62 |
+
# * random choose subset_nums
|
63 |
+
idxs = random.sample(range(len(self.list_of_labels)), self.subset_nums)
|
64 |
+
self.list_of_labels = [self.list_of_labels[idx] for idx in idxs]
|
65 |
+
self.list_of_points = [self.list_of_points[idx] for idx in idxs]
|
66 |
+
|
67 |
+
# * print len
|
68 |
+
print(f"Load {len(self.list_of_points)} data from {self.save_path}.")
|
69 |
+
|
70 |
+
def __len__(self):
|
71 |
+
return len(self.list_of_labels)
|
72 |
+
|
73 |
+
def _get_item(self, index):
|
74 |
+
point_set, label = self.list_of_points[index], self.list_of_labels[index]
|
75 |
+
|
76 |
+
if self.npoints < point_set.shape[0]:
|
77 |
+
if self.random_sample:
|
78 |
+
# * random sample
|
79 |
+
point_set = point_set[np.random.choice(point_set.shape[0], self.npoints, replace=False)]
|
80 |
+
else:
|
81 |
+
point_set = farthest_point_sample(point_set, self.npoints)
|
82 |
+
|
83 |
+
point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
|
84 |
+
if not self.use_normals:
|
85 |
+
point_set = point_set[:, 0:3]
|
86 |
+
|
87 |
+
if self.use_height:
|
88 |
+
self.gravity_dim = 1
|
89 |
+
height_array = point_set[:, self.gravity_dim:self.gravity_dim + 1] - point_set[:,
|
90 |
+
self.gravity_dim:self.gravity_dim + 1].min()
|
91 |
+
point_set = np.concatenate((point_set, height_array), axis=1)
|
92 |
+
|
93 |
+
point_set = np.concatenate((point_set, np.zeros_like(point_set)), axis=-1) if self.use_color else point_set
|
94 |
+
|
95 |
+
return point_set, label.item() # * ndarray, int
|
96 |
+
|
97 |
+
def pc_norm(self, pc):
|
98 |
+
""" pc: NxC, return NxC """
|
99 |
+
xyz = pc[:, :3]
|
100 |
+
other_feature = pc[:, 3:]
|
101 |
+
|
102 |
+
centroid = np.mean(xyz, axis=0)
|
103 |
+
xyz = xyz - centroid
|
104 |
+
m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
|
105 |
+
xyz = xyz / m
|
106 |
+
|
107 |
+
pc = np.concatenate((xyz, other_feature), axis=1)
|
108 |
+
return pc
|
109 |
+
|
110 |
+
def __getitem__(self, index):
|
111 |
+
points, label = self._get_item(index)
|
112 |
+
pt_idxs = np.arange(0, points.shape[0]) # 2048
|
113 |
+
if self.split == 'train':
|
114 |
+
np.random.shuffle(pt_idxs)
|
115 |
+
current_points = points[pt_idxs].copy()
|
116 |
+
|
117 |
+
if self.normalize_pc:
|
118 |
+
# * modelnet point cloud is already normalized
|
119 |
+
current_points = self.pc_norm(current_points)
|
120 |
+
|
121 |
+
current_points = torch.from_numpy(current_points).float() # * N, C tensors
|
122 |
+
label_name = self.categories[int(label)]
|
123 |
+
|
124 |
+
data_dict = {
|
125 |
+
"indice": index, # * int
|
126 |
+
"point_clouds": current_points, # * tensor of N, C
|
127 |
+
"labels": label, # * int
|
128 |
+
"label_names": label_name # * str
|
129 |
+
}
|
130 |
+
|
131 |
+
return data_dict
|
132 |
+
|
133 |
+
if __name__ == '__main__':
|
134 |
+
import argparse
|
135 |
+
|
136 |
+
parser = argparse.ArgumentParser(description='ModelNet Dataset')
|
137 |
+
|
138 |
+
parser.add_argument("--config_path", type=str, default=None, help="config file path.")
|
139 |
+
parser.add_argument("--split", type=str, default="test", help="train or test.")
|
140 |
+
parser.add_argument("--subset_nums", type=int, default=200)
|
141 |
+
|
142 |
+
args = parser.parse_args()
|
143 |
+
|
144 |
+
dataset = ModelNet(config_path=args.config_path, split=args.split, subset_nums=args.subset_nums)
|
145 |
+
|
146 |
+
# * get the first item
|
147 |
+
print(dataset[0])
|
ThirdParty/PointLLM/pointllm/data/modelnet_config/ModelNet40.yaml
ADDED
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
NAME: ModelNet
|
2 |
+
DATA_PATH: data/modelnet40_data
|
3 |
+
NUM_CATEGORY: 40
|
4 |
+
USE_NORMALS: FALSE
|
5 |
+
npoints: 8192
|
6 |
+
random_sampling: TRUE
|
7 |
+
use_height: FALSE
|
8 |
+
use_normals: FALSE
|
ThirdParty/PointLLM/pointllm/data/object_point_dataset.py
ADDED
@@ -0,0 +1,250 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import json
|
3 |
+
import torch
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
import copy
|
7 |
+
import transformers
|
8 |
+
from torch.utils.data import Dataset
|
9 |
+
|
10 |
+
from .utils import *
|
11 |
+
|
12 |
+
|
13 |
+
def make_object_point_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
|
14 |
+
"""Make dataset and collator for Joint3Ddataset with text and point cloud data."""
|
15 |
+
"""Initialize datasets."""
|
16 |
+
|
17 |
+
data_collator = DataCollatorForPointTextDataset(tokenizer=tokenizer)
|
18 |
+
if data_args.split_train_val:
|
19 |
+
print("Loading training datasets.")
|
20 |
+
train_dataset = ObjectPointCloudDataset(
|
21 |
+
split='train',
|
22 |
+
data_path=data_args.data_path,
|
23 |
+
anno_path=data_args.anno_path,
|
24 |
+
pointnum=data_args.pointnum,
|
25 |
+
conversation_types=data_args.conversation_types,
|
26 |
+
tokenizer=tokenizer,
|
27 |
+
use_color=data_args.use_color,
|
28 |
+
data_args=data_args
|
29 |
+
)
|
30 |
+
print("Done!")
|
31 |
+
if data_args.data_debug_num > 0:
|
32 |
+
print('Debug mode, using training set as val set.')
|
33 |
+
val_dataset = train_dataset
|
34 |
+
else:
|
35 |
+
# * make a val dataset
|
36 |
+
print("Loading validation datasets.")
|
37 |
+
val_dataset = ObjectPointCloudDataset(
|
38 |
+
split='val', # * load train split
|
39 |
+
data_path=data_args.data_path,
|
40 |
+
anno_path=data_args.anno_path,
|
41 |
+
pointnum=data_args.pointnum,
|
42 |
+
conversation_types=data_args.conversation_types,
|
43 |
+
tokenizer=tokenizer,
|
44 |
+
use_color=data_args.use_color,
|
45 |
+
data_args=data_args
|
46 |
+
)
|
47 |
+
return dict(train_dataset=train_dataset, eval_dataset=val_dataset, data_collator=data_collator)
|
48 |
+
else:
|
49 |
+
# * use all data as training data
|
50 |
+
train_dataset = ObjectPointCloudDataset(
|
51 |
+
split='train',
|
52 |
+
data_path=data_args.data_path,
|
53 |
+
anno_path=data_args.anno_path,
|
54 |
+
pointnum=data_args.pointnum,
|
55 |
+
conversation_types=data_args.conversation_types,
|
56 |
+
use_color=data_args.use_color,
|
57 |
+
tokenizer=tokenizer,
|
58 |
+
data_args=data_args
|
59 |
+
)
|
60 |
+
return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
|
61 |
+
|
62 |
+
class ObjectPointCloudDataset(Dataset):
|
63 |
+
"""Dataset utilities for objaverse."""
|
64 |
+
def __init__(self,
|
65 |
+
data_path=None,
|
66 |
+
anno_path=None,
|
67 |
+
tokenizer=None,
|
68 |
+
pointnum=8192,
|
69 |
+
split='train',
|
70 |
+
conversation_types=None, # * default is simple_des, used for stage1 pre-train
|
71 |
+
use_color=True,
|
72 |
+
data_args=None):
|
73 |
+
|
74 |
+
"""
|
75 |
+
split: only considered when data_args.split_train_val is True.
|
76 |
+
conversation_types: tuple, used to filter the data, default is ('simple_description'), other types is:
|
77 |
+
"detailed_description", "single_round", "multi_round".
|
78 |
+
tokenizer: load point clouds only if None
|
79 |
+
"""
|
80 |
+
super(ObjectPointCloudDataset, self).__init__()
|
81 |
+
|
82 |
+
"""Initialize dataset with object point clouds and text"""
|
83 |
+
self.data_path = data_path
|
84 |
+
self.anno_path = anno_path
|
85 |
+
self.tokenizer = tokenizer
|
86 |
+
self.split = split
|
87 |
+
if conversation_types is None:
|
88 |
+
self.conversation_types = ("simple_description",)
|
89 |
+
else:
|
90 |
+
self.conversation_types = conversation_types
|
91 |
+
|
92 |
+
self.data_args = data_args
|
93 |
+
self.normalize_pc = True
|
94 |
+
self.use_color = use_color
|
95 |
+
|
96 |
+
self.pointnum = pointnum
|
97 |
+
self.point_backbone_config = data_args.point_backbone_config if data_args is not None else None
|
98 |
+
self.point_indicator = '<point>'
|
99 |
+
|
100 |
+
# Load the data list from JSON
|
101 |
+
print(f"Loading anno file from {anno_path}.")
|
102 |
+
with open(anno_path, "r") as json_file:
|
103 |
+
self.list_data_dict = json.load(json_file)
|
104 |
+
|
105 |
+
# * print the conversations_type
|
106 |
+
print(f"Using conversation_type: {self.conversation_types}")
|
107 |
+
# * print before filtering
|
108 |
+
print(f"Before filtering, the dataset size is: {len(self.list_data_dict)}.")
|
109 |
+
|
110 |
+
# * iterate the list and filter
|
111 |
+
# * these two ids have corrupted colored point files, so filter them when use_color is True
|
112 |
+
filter_ids = ['6760e543e1d645d5aaacd3803bcae524', 'b91c0711149d460a8004f9c06d3b7f38'] if self.use_color else []
|
113 |
+
|
114 |
+
# Iterate the list, filter those "conversation_type" not in self.conversation_types
|
115 |
+
self.list_data_dict = [
|
116 |
+
data for data in self.list_data_dict
|
117 |
+
if data.get('conversation_type', 'simple_description') in self.conversation_types
|
118 |
+
and data.get('object_id') not in filter_ids
|
119 |
+
]
|
120 |
+
|
121 |
+
# * print after filtering
|
122 |
+
print(f"After filtering, the dataset size is: {len(self.list_data_dict)}.")
|
123 |
+
# * print the size of different conversation_type
|
124 |
+
for conversation_type in self.conversation_types:
|
125 |
+
print(f"Number of {conversation_type}: {len([data for data in self.list_data_dict if data.get('conversation_type', 'simple_description') == conversation_type])}")
|
126 |
+
|
127 |
+
if self.data_args is not None and self.data_args.data_debug_num > 0:
|
128 |
+
self.list_data_dict = self.list_data_dict[:self.data_args.data_debug_num]
|
129 |
+
# * print all the scan_id in debug mode, not using for loop
|
130 |
+
print('Debug mode, using: ' + ' '.join([data['object_id'] for data in self.list_data_dict]))
|
131 |
+
elif self.data_args is not None and self.data_args.split_train_val:
|
132 |
+
# * split train and val with 9:1 ratios
|
133 |
+
if self.split == 'train':
|
134 |
+
self.list_data_dict = self.list_data_dict[:int(self.data_args.split_ratio * len(self.list_data_dict))]
|
135 |
+
print(f"Train set size: {len(self.list_data_dict)}")
|
136 |
+
else:
|
137 |
+
self.list_data_dict = self.list_data_dict[int(self.data_args.split_ratio * len(self.list_data_dict)):]
|
138 |
+
print(f"Val set size: {len(self.list_data_dict)}")
|
139 |
+
|
140 |
+
def _load_point_cloud(self, object_id, type='objaverse'):
|
141 |
+
if type == 'objaverse':
|
142 |
+
return self._load_objaverse_point_cloud(object_id)
|
143 |
+
|
144 |
+
def _load_objaverse_point_cloud(self, object_id):
|
145 |
+
filename = f"{object_id}_{self.pointnum}.npy"
|
146 |
+
point_cloud = np.load(os.path.join(self.data_path, filename))
|
147 |
+
|
148 |
+
if not self.use_color:
|
149 |
+
point_cloud = point_cloud[:, :3]
|
150 |
+
|
151 |
+
return point_cloud
|
152 |
+
|
153 |
+
def pc_norm(self, pc):
|
154 |
+
""" pc: NxC, return NxC """
|
155 |
+
xyz = pc[:, :3]
|
156 |
+
other_feature = pc[:, 3:]
|
157 |
+
|
158 |
+
centroid = np.mean(xyz, axis=0)
|
159 |
+
xyz = xyz - centroid
|
160 |
+
m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
|
161 |
+
xyz = xyz / m
|
162 |
+
|
163 |
+
pc = np.concatenate((xyz, other_feature), axis=1)
|
164 |
+
return pc
|
165 |
+
|
166 |
+
def __getitem__(self, index):
|
167 |
+
sources = self.list_data_dict[index]
|
168 |
+
if isinstance(index, int):
|
169 |
+
sources = [sources]
|
170 |
+
assert len(sources) == 1, "sources should be a list"
|
171 |
+
if self.point_indicator in sources[0]['conversations'][0]['value']:
|
172 |
+
|
173 |
+
object_id = self.list_data_dict[index]['object_id']
|
174 |
+
|
175 |
+
# Point cloud representation
|
176 |
+
point_cloud = self._load_point_cloud(object_id) # * N, C
|
177 |
+
if self.normalize_pc:
|
178 |
+
point_cloud = self.pc_norm(point_cloud) # * need to norm since point encoder is norm
|
179 |
+
|
180 |
+
if self.tokenizer is None:
|
181 |
+
data_dict = dict(
|
182 |
+
point_clouds=torch.from_numpy(point_cloud.astype(np.float32)),
|
183 |
+
object_ids=object_id
|
184 |
+
)
|
185 |
+
return data_dict
|
186 |
+
|
187 |
+
sources = preprocess_multimodal_point_cloud(
|
188 |
+
copy.deepcopy([e["conversations"] for e in sources]), self.point_backbone_config, point_indicator=self.point_indicator)
|
189 |
+
else:
|
190 |
+
sources = copy.deepcopy([e["conversations"] for e in sources])
|
191 |
+
|
192 |
+
data_dict = preprocess_v1(
|
193 |
+
sources,
|
194 |
+
self.tokenizer)
|
195 |
+
|
196 |
+
if isinstance(index, int):
|
197 |
+
data_dict = dict(input_ids=data_dict["input_ids"][0],
|
198 |
+
labels=data_dict["labels"][0])
|
199 |
+
|
200 |
+
# point exist in the data
|
201 |
+
if self.point_indicator in self.list_data_dict[index]['conversations'][0]['value']:
|
202 |
+
data_dict['point_clouds'] = torch.from_numpy(point_cloud.astype(np.float32))
|
203 |
+
|
204 |
+
return data_dict
|
205 |
+
|
206 |
+
def __len__(self):
|
207 |
+
"""Return number of utterances."""
|
208 |
+
return len(self.list_data_dict)
|
209 |
+
|
210 |
+
if __name__ == '__main__':
|
211 |
+
import argparse
|
212 |
+
parser = argparse.ArgumentParser()
|
213 |
+
|
214 |
+
parser.add_argument("--data_path", default="data/objaverse_data", type=str,
|
215 |
+
help="Path to the data directory.")
|
216 |
+
parser.add_argument("--anno_path", default=None, type=str, required=True,
|
217 |
+
help="Path to the annotation file.")
|
218 |
+
parser.add_argument("--split", default='train', type=str,
|
219 |
+
help="Whether to use the train or validation dataset.")
|
220 |
+
parser.add_argument("--pointnum", default=8192, type=int,
|
221 |
+
help="Number of points in the point cloud.")
|
222 |
+
parser.add_argument("--data_debug_num", default=0, type=int,
|
223 |
+
help="Number of data to debug with.")
|
224 |
+
parser.add_argument("--split_train_val", default=False, type=bool,
|
225 |
+
help="Whether to split the dataset into training and validation.")
|
226 |
+
parser.add_argument("--split_ratio", default=0.9, type=float,
|
227 |
+
help="The ratio of training to validation data.")
|
228 |
+
parser.add_argument("--tokenizer_path", default=None, type=str, required=True,
|
229 |
+
help="Path to the tokenizer config file.")
|
230 |
+
|
231 |
+
args = parser.parse_args()
|
232 |
+
|
233 |
+
# Initialize tokenizer
|
234 |
+
tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer_path)
|
235 |
+
|
236 |
+
args.point_backbone_config = None
|
237 |
+
|
238 |
+
# Initialize dataset
|
239 |
+
dataset = ObjectPointCloudDataset(
|
240 |
+
data_path=args.data_path,
|
241 |
+
anno_path=args.anno_path,
|
242 |
+
pointnum=args.pointnum,
|
243 |
+
split=args.split,
|
244 |
+
tokenizer=tokenizer,
|
245 |
+
data_args=args
|
246 |
+
)
|
247 |
+
|
248 |
+
# Example usage
|
249 |
+
print(f'Dataset length: {len(dataset)}')
|
250 |
+
|
ThirdParty/PointLLM/pointllm/data/utils.py
ADDED
@@ -0,0 +1,236 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from collections import OrderedDict, defaultdict
|
2 |
+
|
3 |
+
import transformers
|
4 |
+
from pointllm import conversation as conversation_lib
|
5 |
+
from dataclasses import dataclass
|
6 |
+
from typing import Optional, Dict, Sequence
|
7 |
+
import torch
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import os
|
11 |
+
|
12 |
+
IGNORE_INDEX = -100
|
13 |
+
|
14 |
+
# * Sample Usage:
|
15 |
+
# * from utils import LRUCache
|
16 |
+
# * cache = LRUCache(capacity, max_access_count)
|
17 |
+
# if self.cache is None:
|
18 |
+
# info_data = self.multiview_scannet[info_index]
|
19 |
+
# else:
|
20 |
+
# info_data = self.cache.get(info_index)
|
21 |
+
# if info_data is None or self.cache.get_access_count(info_index) >= self.cache.max_access_count:
|
22 |
+
# # If not in cache, or accessed max_access_count times, load it and put it in cache
|
23 |
+
# info_data = self.multiview_scannet[info_index]
|
24 |
+
# self.cache.put(info_index, info_data)
|
25 |
+
# self.cache.reset_access_count(info_index)
|
26 |
+
|
27 |
+
class LRUCache:
|
28 |
+
def __init__(self, capacity, max_access_count):
|
29 |
+
self.cache = OrderedDict()
|
30 |
+
self.access_count = defaultdict(int)
|
31 |
+
self.capacity = capacity
|
32 |
+
self.max_access_count = max_access_count
|
33 |
+
|
34 |
+
def get(self, key):
|
35 |
+
if key not in self.cache:
|
36 |
+
return None
|
37 |
+
value = self.cache.pop(key)
|
38 |
+
self.cache[key] = value # Put key as the newest one
|
39 |
+
self.access_count[key] += 1
|
40 |
+
return value
|
41 |
+
|
42 |
+
def put(self, key, value):
|
43 |
+
if key in self.cache: # Update the value and put it as newest
|
44 |
+
self.cache.pop(key)
|
45 |
+
elif len(self.cache) == self.capacity: # If cache is full
|
46 |
+
oldest_key = next(iter(self.cache))
|
47 |
+
self.cache.popitem(last=False) # Remove oldest item
|
48 |
+
del self.access_count[oldest_key] # Remove the corresponding access count
|
49 |
+
self.cache[key] = value
|
50 |
+
self.access_count[key] = 1
|
51 |
+
|
52 |
+
def get_access_count(self, key):
|
53 |
+
return self.access_count.get(key, 0)
|
54 |
+
|
55 |
+
def reset_access_count(self, key):
|
56 |
+
self.access_count[key] = 0
|
57 |
+
|
58 |
+
|
59 |
+
def preprocess_v1(
|
60 |
+
sources,
|
61 |
+
tokenizer: transformers.PreTrainedTokenizer,
|
62 |
+
) -> Dict:
|
63 |
+
conv = conversation_lib.default_conversation.copy()
|
64 |
+
roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
|
65 |
+
|
66 |
+
# Apply prompt templates
|
67 |
+
conversations = []
|
68 |
+
for i, source in enumerate(sources):
|
69 |
+
if roles[source[0]["from"]] != conv.roles[0]:
|
70 |
+
# Skip the first one if it is not from human
|
71 |
+
source = source[1:]
|
72 |
+
|
73 |
+
conv.messages = []
|
74 |
+
for j, sentence in enumerate(source):
|
75 |
+
role = roles[sentence["from"]]
|
76 |
+
assert role == conv.roles[j % 2], f"{i}"
|
77 |
+
conv.append_message(role, sentence["value"])
|
78 |
+
conversations.append(conv.get_prompt())
|
79 |
+
|
80 |
+
# Tokenize conversations
|
81 |
+
input_ids = tokenizer(
|
82 |
+
conversations,
|
83 |
+
return_tensors="pt",
|
84 |
+
padding="longest",
|
85 |
+
max_length=tokenizer.model_max_length,
|
86 |
+
truncation=True,
|
87 |
+
).input_ids
|
88 |
+
targets = input_ids.clone()
|
89 |
+
|
90 |
+
assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
|
91 |
+
|
92 |
+
# Mask targets
|
93 |
+
sep = conv.sep + conv.roles[1] + ": "
|
94 |
+
for conversation, target in zip(conversations, targets):
|
95 |
+
total_len = int(target.ne(tokenizer.pad_token_id).sum())
|
96 |
+
|
97 |
+
rounds = conversation.split(conv.sep2)
|
98 |
+
cur_len = 1
|
99 |
+
target[:cur_len] = IGNORE_INDEX
|
100 |
+
for i, rou in enumerate(rounds):
|
101 |
+
if rou == "":
|
102 |
+
break
|
103 |
+
|
104 |
+
parts = rou.split(sep)
|
105 |
+
if len(parts) != 2: # * can handle padded tokens
|
106 |
+
break
|
107 |
+
parts[0] += sep
|
108 |
+
round_len = len(tokenizer(rou).input_ids)
|
109 |
+
instruction_len = len(tokenizer(parts[0]).input_ids) - 2
|
110 |
+
|
111 |
+
target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
|
112 |
+
|
113 |
+
cur_len += round_len
|
114 |
+
target[cur_len:] = IGNORE_INDEX # * this is necessary for padded tokens
|
115 |
+
|
116 |
+
if cur_len < tokenizer.model_max_length:
|
117 |
+
if cur_len != total_len: # * unk tokens in the dialogue will cause this.
|
118 |
+
target[:] = IGNORE_INDEX
|
119 |
+
print(
|
120 |
+
f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
|
121 |
+
f" (ignored)"
|
122 |
+
)
|
123 |
+
|
124 |
+
return dict(
|
125 |
+
input_ids=input_ids,
|
126 |
+
labels=targets,
|
127 |
+
)
|
128 |
+
|
129 |
+
def preprocess_multimodal_point_cloud(
|
130 |
+
sources: Sequence[str],
|
131 |
+
point_backbone_config: dict,
|
132 |
+
point_indicator: str = "<point>",
|
133 |
+
) -> Dict:
|
134 |
+
point_token_len = point_backbone_config['point_token_len']
|
135 |
+
default_point_patch_token = point_backbone_config['default_point_patch_token']
|
136 |
+
|
137 |
+
for source in sources:
|
138 |
+
for sentence in source:
|
139 |
+
replace_token = default_point_patch_token * point_token_len
|
140 |
+
if point_backbone_config['mm_use_point_start_end']:
|
141 |
+
replace_token = point_backbone_config['default_point_start_token']+ replace_token + point_backbone_config['default_point_end_token']
|
142 |
+
sentence["value"] = sentence["value"].replace(point_indicator, replace_token)
|
143 |
+
|
144 |
+
return sources
|
145 |
+
|
146 |
+
def pc_norm(pc):
|
147 |
+
""" pc: NxC, return NxC """
|
148 |
+
xyz = pc[:, :3]
|
149 |
+
other_feature = pc[:, 3:]
|
150 |
+
|
151 |
+
centroid = np.mean(xyz, axis=0)
|
152 |
+
xyz = xyz - centroid
|
153 |
+
m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
|
154 |
+
xyz = xyz / m
|
155 |
+
|
156 |
+
pc = np.concatenate((xyz, other_feature), axis=1)
|
157 |
+
return pc
|
158 |
+
|
159 |
+
def load_objaverse_point_cloud(data_path, object_id, pointnum=8192, use_color=False):
|
160 |
+
filename = f"{object_id}_{pointnum}.npy"
|
161 |
+
point_cloud = np.load(os.path.join(data_path, filename))
|
162 |
+
|
163 |
+
# * normalize
|
164 |
+
point_cloud = pc_norm(point_cloud)
|
165 |
+
|
166 |
+
if not use_color:
|
167 |
+
point_cloud = point_cloud[:, :3]
|
168 |
+
|
169 |
+
return point_cloud
|
170 |
+
|
171 |
+
@dataclass
|
172 |
+
class DataCollatorForPointTextDataset(object):
|
173 |
+
"""Collate examples for mixed dataset with text and point cloud data."""
|
174 |
+
|
175 |
+
tokenizer: transformers.PreTrainedTokenizer
|
176 |
+
|
177 |
+
def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
|
178 |
+
input_ids, labels = tuple([instance[key] for instance in instances]
|
179 |
+
for key in ("input_ids", "labels"))
|
180 |
+
input_ids = torch.nn.utils.rnn.pad_sequence(
|
181 |
+
input_ids,
|
182 |
+
batch_first=True,
|
183 |
+
padding_value=self.tokenizer.pad_token_id)
|
184 |
+
labels = torch.nn.utils.rnn.pad_sequence(labels,
|
185 |
+
batch_first=True,
|
186 |
+
padding_value=IGNORE_INDEX)
|
187 |
+
batch = dict(
|
188 |
+
input_ids=input_ids,
|
189 |
+
labels=labels,
|
190 |
+
attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
|
191 |
+
)
|
192 |
+
|
193 |
+
if 'point_clouds' in instances[0]:
|
194 |
+
point_clouds = [instance['point_clouds'] for instance in instances]
|
195 |
+
if all(x is not None and x.shape == point_clouds[0].shape for x in point_clouds): # * point_clouds have different shapes
|
196 |
+
batch['point_clouds'] = torch.stack(point_clouds)
|
197 |
+
else:
|
198 |
+
batch['point_clouds'] = point_clouds # * return as lists
|
199 |
+
|
200 |
+
return batch
|
201 |
+
|
202 |
+
def farthest_point_sample(point, npoint):
|
203 |
+
"""
|
204 |
+
Input:
|
205 |
+
xyz: pointcloud data, [N, D]
|
206 |
+
npoint: number of samples
|
207 |
+
Return:
|
208 |
+
centroids: sampled pointcloud index, [npoint, D]
|
209 |
+
"""
|
210 |
+
N, D = point.shape
|
211 |
+
xyz = point[:,:3]
|
212 |
+
centroids = np.zeros((npoint,))
|
213 |
+
distance = np.ones((N,)) * 1e10
|
214 |
+
farthest = np.random.randint(0, N)
|
215 |
+
for i in range(npoint):
|
216 |
+
centroids[i] = farthest
|
217 |
+
centroid = xyz[farthest, :]
|
218 |
+
dist = np.sum((xyz - centroid) ** 2, -1)
|
219 |
+
mask = dist < distance
|
220 |
+
distance[mask] = dist[mask]
|
221 |
+
farthest = np.argmax(distance, -1)
|
222 |
+
point = point[centroids.astype(np.int32)]
|
223 |
+
return point
|
224 |
+
|
225 |
+
def pc_normalize(pc):
|
226 |
+
"""
|
227 |
+
pc: Nx3 array
|
228 |
+
This functions normalizes a point cloud to fit within a unit sphere.
|
229 |
+
It first calculates the centroid of the point cloud and then subtracts
|
230 |
+
it from all points before scaling all points to fit within a unit sphere.
|
231 |
+
"""
|
232 |
+
centroid = np.mean(pc, axis=0)
|
233 |
+
pc = pc - centroid
|
234 |
+
m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
|
235 |
+
pc = pc / m
|
236 |
+
return pc
|
ThirdParty/PointLLM/pointllm/eval/PointLLM_chat.py
ADDED
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
from pointllm.conversation import conv_templates, SeparatorStyle
|
6 |
+
from pointllm.utils import disable_torch_init
|
7 |
+
from pointllm.model import *
|
8 |
+
from pointllm.model.utils import KeywordsStoppingCriteria
|
9 |
+
|
10 |
+
from pointllm.data import load_objaverse_point_cloud
|
11 |
+
|
12 |
+
import os
|
13 |
+
|
14 |
+
def load_point_cloud(args):
|
15 |
+
object_id = args.object_id
|
16 |
+
print(f"[INFO] Loading point clouds using object_id: {object_id}")
|
17 |
+
point_cloud = load_objaverse_point_cloud(args.data_path, object_id, pointnum=8192, use_color=True)
|
18 |
+
|
19 |
+
return object_id, torch.from_numpy(point_cloud).unsqueeze_(0).to(torch.float32)
|
20 |
+
|
21 |
+
def init_model(args):
|
22 |
+
# Model
|
23 |
+
disable_torch_init()
|
24 |
+
|
25 |
+
model_path = args.model_path
|
26 |
+
print(f'[INFO] Model name: {model_path}')
|
27 |
+
|
28 |
+
tokenizer = AutoTokenizer.from_pretrained(model_path)
|
29 |
+
model = PointLLMLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=False, use_cache=True, torch_dtype=args.torch_dtype).cuda()
|
30 |
+
model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer)
|
31 |
+
|
32 |
+
model.eval()
|
33 |
+
|
34 |
+
mm_use_point_start_end = getattr(model.config, "mm_use_point_start_end", False)
|
35 |
+
# Add special tokens ind to model.point_config
|
36 |
+
point_backbone_config = model.get_model().point_backbone_config
|
37 |
+
|
38 |
+
if mm_use_point_start_end:
|
39 |
+
if "v1" in model_path.lower():
|
40 |
+
conv_mode = "vicuna_v1_1"
|
41 |
+
else:
|
42 |
+
raise NotImplementedError
|
43 |
+
|
44 |
+
conv = conv_templates[conv_mode].copy()
|
45 |
+
|
46 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
47 |
+
keywords = [stop_str]
|
48 |
+
|
49 |
+
return model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv
|
50 |
+
|
51 |
+
def start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv):
|
52 |
+
point_token_len = point_backbone_config['point_token_len']
|
53 |
+
default_point_patch_token = point_backbone_config['default_point_patch_token']
|
54 |
+
default_point_start_token = point_backbone_config['default_point_start_token']
|
55 |
+
default_point_end_token = point_backbone_config['default_point_end_token']
|
56 |
+
# The while loop will keep running until the user decides to quit
|
57 |
+
print("[INFO] Starting conversation... Enter 'q' to exit the program and enter 'exit' to exit the current conversation.")
|
58 |
+
while True:
|
59 |
+
print("-" * 80)
|
60 |
+
# Prompt for object_id
|
61 |
+
object_id = input("[INFO] Please enter the object_id or 'q' to quit: ")
|
62 |
+
|
63 |
+
# Check if the user wants to quit
|
64 |
+
if object_id.lower() == 'q':
|
65 |
+
print("[INFO] Quitting...")
|
66 |
+
break
|
67 |
+
else:
|
68 |
+
# print info
|
69 |
+
print(f"[INFO] Chatting with object_id: {object_id}.")
|
70 |
+
|
71 |
+
# Update args with new object_id
|
72 |
+
args.object_id = object_id.strip()
|
73 |
+
|
74 |
+
# Load the point cloud data
|
75 |
+
try:
|
76 |
+
id, point_clouds = load_point_cloud(args)
|
77 |
+
except Exception as e:
|
78 |
+
print(f"[ERROR] {e}")
|
79 |
+
continue
|
80 |
+
point_clouds = point_clouds.cuda().to(args.torch_dtype)
|
81 |
+
|
82 |
+
# Reset the conversation template
|
83 |
+
conv.reset()
|
84 |
+
|
85 |
+
print("-" * 80)
|
86 |
+
|
87 |
+
# Start a loop for multiple rounds of dialogue
|
88 |
+
for i in range(100):
|
89 |
+
# This if-else block ensures the initial question from the user is included in the conversation
|
90 |
+
qs = input(conv.roles[0] + ': ')
|
91 |
+
if qs == 'exit':
|
92 |
+
break
|
93 |
+
|
94 |
+
if i == 0:
|
95 |
+
if mm_use_point_start_end:
|
96 |
+
qs = default_point_start_token + default_point_patch_token * point_token_len + default_point_end_token + '\n' + qs
|
97 |
+
else:
|
98 |
+
qs = default_point_patch_token * point_token_len + '\n' + qs
|
99 |
+
|
100 |
+
# Append the new message to the conversation history
|
101 |
+
conv.append_message(conv.roles[0], qs)
|
102 |
+
conv.append_message(conv.roles[1], None)
|
103 |
+
prompt = conv.get_prompt()
|
104 |
+
inputs = tokenizer([prompt])
|
105 |
+
|
106 |
+
input_ids = torch.as_tensor(inputs.input_ids).cuda()
|
107 |
+
|
108 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
109 |
+
stop_str = keywords[0]
|
110 |
+
|
111 |
+
with torch.inference_mode():
|
112 |
+
output_ids = model.generate(
|
113 |
+
input_ids,
|
114 |
+
point_clouds=point_clouds,
|
115 |
+
do_sample=True,
|
116 |
+
temperature=1.0,
|
117 |
+
top_k=50,
|
118 |
+
max_length=2048,
|
119 |
+
top_p=0.95,
|
120 |
+
stopping_criteria=[stopping_criteria])
|
121 |
+
|
122 |
+
input_token_len = input_ids.shape[1]
|
123 |
+
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
124 |
+
if n_diff_input_output > 0:
|
125 |
+
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
126 |
+
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
127 |
+
outputs = outputs.strip()
|
128 |
+
if outputs.endswith(stop_str):
|
129 |
+
outputs = outputs[:-len(stop_str)]
|
130 |
+
outputs = outputs.strip()
|
131 |
+
|
132 |
+
# Append the model's response to the conversation history
|
133 |
+
conv.pop_last_none_message()
|
134 |
+
conv.append_message(conv.roles[1], outputs)
|
135 |
+
print(f'{conv.roles[1]}: {outputs}\n')
|
136 |
+
|
137 |
+
if __name__ == "__main__":
|
138 |
+
parser = argparse.ArgumentParser()
|
139 |
+
parser.add_argument("--model_name", type=str, \
|
140 |
+
default="RunsenXu/PointLLM_7B_v1.2")
|
141 |
+
|
142 |
+
parser.add_argument("--data_path", type=str, default="data/objaverse_data")
|
143 |
+
parser.add_argument("--torch_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"])
|
144 |
+
|
145 |
+
args = parser.parse_args()
|
146 |
+
|
147 |
+
dtype_mapping = {
|
148 |
+
"float32": torch.float32,
|
149 |
+
"float16": torch.float16,
|
150 |
+
"bfloat16": torch.bfloat16,
|
151 |
+
}
|
152 |
+
|
153 |
+
args.torch_dtype = dtype_mapping[args.torch_dtype]
|
154 |
+
|
155 |
+
model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv = init_model(args)
|
156 |
+
|
157 |
+
start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv)
|
ThirdParty/PointLLM/pointllm/eval/chat_gradio.py
ADDED
@@ -0,0 +1,394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import argparse
|
2 |
+
from transformers import AutoTokenizer
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
from pointllm.conversation import conv_templates, SeparatorStyle
|
6 |
+
from pointllm.utils import disable_torch_init
|
7 |
+
from pointllm.model import *
|
8 |
+
from pointllm.model.utils import KeywordsStoppingCriteria
|
9 |
+
import numpy as np
|
10 |
+
|
11 |
+
from pointllm.data import pc_norm, farthest_point_sample
|
12 |
+
|
13 |
+
import os
|
14 |
+
|
15 |
+
# Additional import for gradio
|
16 |
+
import gradio as gr
|
17 |
+
import open3d as o3d
|
18 |
+
import plotly.graph_objects as go
|
19 |
+
import objaverse
|
20 |
+
import time
|
21 |
+
|
22 |
+
import logging
|
23 |
+
|
24 |
+
|
25 |
+
def change_input_method(input_method):
|
26 |
+
if input_method == 'File':
|
27 |
+
result = [gr.update(visible=True),
|
28 |
+
gr.update(visible=False)]
|
29 |
+
elif input_method == 'Object ID':
|
30 |
+
result = [gr.update(visible=False),
|
31 |
+
gr.update(visible=True)]
|
32 |
+
return result
|
33 |
+
|
34 |
+
def init_model(args):
|
35 |
+
# Model
|
36 |
+
disable_torch_init()
|
37 |
+
model_name = os.path.expanduser(args.model_name)
|
38 |
+
|
39 |
+
# * print the model_name (get the basename)
|
40 |
+
print(f'[INFO] Model name: {os.path.basename(model_name)}')
|
41 |
+
logging.warning(f'Model name: {os.path.basename(model_name)}')
|
42 |
+
|
43 |
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
44 |
+
model = PointLLMLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=False, use_cache=True).cuda()
|
45 |
+
model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer)
|
46 |
+
|
47 |
+
model.eval()
|
48 |
+
|
49 |
+
mm_use_point_start_end = getattr(model.config, "mm_use_point_start_end", False)
|
50 |
+
# Add special tokens ind to model.point_config
|
51 |
+
point_backbone_config = model.get_model().point_backbone_config
|
52 |
+
|
53 |
+
conv = conv_templates["vicuna_v1_1"].copy()
|
54 |
+
|
55 |
+
stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
|
56 |
+
keywords = [stop_str]
|
57 |
+
|
58 |
+
return model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv
|
59 |
+
|
60 |
+
def start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv):
|
61 |
+
point_token_len = point_backbone_config['point_token_len']
|
62 |
+
default_point_patch_token = point_backbone_config['default_point_patch_token']
|
63 |
+
default_point_start_token = point_backbone_config['default_point_start_token']
|
64 |
+
default_point_end_token = point_backbone_config['default_point_end_token']
|
65 |
+
|
66 |
+
# The while loop will keep running until the user decides to quit
|
67 |
+
print("[INFO] Starting conversation...")
|
68 |
+
logging.warning("Starting conversation...")
|
69 |
+
while True:
|
70 |
+
print("-" * 80)
|
71 |
+
logging.warning("-" * 80)
|
72 |
+
|
73 |
+
# Reset the conversation template
|
74 |
+
conv.reset()
|
75 |
+
|
76 |
+
def confirm_point_cloud(input_choice, object_id_input, point_cloud_input, chatbot, answer_time, conv):
|
77 |
+
objects = None
|
78 |
+
data = None
|
79 |
+
object_id_input = object_id_input.strip()
|
80 |
+
|
81 |
+
print("%" * 80)
|
82 |
+
logging.warning("%" * 80)
|
83 |
+
|
84 |
+
if input_choice == 'File':
|
85 |
+
file = point_cloud_input.name
|
86 |
+
print(f"Uploading file: {file}.")
|
87 |
+
logging.warning(f"Uploading file: {file}.")
|
88 |
+
elif input_choice == 'Object ID':
|
89 |
+
file = os.path.join(args.data_path, "{}_8192.npy".format(object_id_input))
|
90 |
+
print(f"Object_id: {object_id_input}")
|
91 |
+
logging.warning(f"Object_id: {object_id_input}")
|
92 |
+
|
93 |
+
object_uids = [object_id_input]
|
94 |
+
objects = objaverse.load_objects(uids=object_uids)
|
95 |
+
print("%" * 80)
|
96 |
+
logging.warning("%" * 80)
|
97 |
+
|
98 |
+
manual_no_color = "no_color" in file
|
99 |
+
|
100 |
+
try:
|
101 |
+
if '.ply' in file:
|
102 |
+
pcd = o3d.io.read_point_cloud(file)
|
103 |
+
points = np.asarray(pcd.points) # xyz
|
104 |
+
colors = np.asarray(pcd.colors) # rgb, if available
|
105 |
+
# * if no colors actually, empty array
|
106 |
+
if colors.size == 0:
|
107 |
+
colors = None
|
108 |
+
elif '.npy' in file:
|
109 |
+
data = np.load(file)
|
110 |
+
if data.shape[1] >= 3:
|
111 |
+
points = data[:, :3]
|
112 |
+
else:
|
113 |
+
raise ValueError("Input array has the wrong shape. Expected: [N, 3]. Got: {}.".format(data.shape))
|
114 |
+
colors = None if data.shape[1] < 6 else data[:, 3:6]
|
115 |
+
else:
|
116 |
+
raise ValueError("Not supported data format.")
|
117 |
+
# error
|
118 |
+
except Exception as e:
|
119 |
+
print(f"[ERROR] {e}")
|
120 |
+
logging.warning(f"[ERROR] {e}")
|
121 |
+
|
122 |
+
chatbot_system_message = "Sorry. The Objaverse id is not supported or the uploaded file has something wrong!"
|
123 |
+
print(f"[ChatBot System Message]: {chatbot_system_message}")
|
124 |
+
logging.warning(f"[ChatBot System Message]: {chatbot_system_message}")
|
125 |
+
|
126 |
+
outputs = f"<span style='color: red;'>[System] {chatbot_system_message}</span>" # "You upload a new Points Cloud"
|
127 |
+
chatbot = chatbot + [[None, outputs]]
|
128 |
+
|
129 |
+
return None, None, chatbot, answer_time, None
|
130 |
+
|
131 |
+
if manual_no_color:
|
132 |
+
colors = None
|
133 |
+
|
134 |
+
if colors is not None:
|
135 |
+
# * if colors in range(0-1)
|
136 |
+
if np.max(colors) <= 1:
|
137 |
+
color_data = np.multiply(colors, 255).astype(int) # Convert float values (0-1) to integers (0-255)
|
138 |
+
# * if colors in range(0-255)
|
139 |
+
elif np.max(colors) <= 255:
|
140 |
+
color_data = colors.astype(int)
|
141 |
+
else:
|
142 |
+
color_data = np.zeros_like(points).astype(int) # Default to black color if RGB information is not available
|
143 |
+
colors = color_data.astype(np.float32) / 255 # model input is (0-1)
|
144 |
+
|
145 |
+
# Convert the RGB color data to a list of RGB strings in the format 'rgb(r, g, b)'
|
146 |
+
color_strings = ['rgb({},{},{})'.format(r, g, b) for r, g, b in color_data]
|
147 |
+
|
148 |
+
fig = go.Figure(
|
149 |
+
data=[
|
150 |
+
go.Scatter3d(
|
151 |
+
x=points[:, 0], y=points[:, 1], z=points[:, 2],
|
152 |
+
mode='markers',
|
153 |
+
marker=dict(
|
154 |
+
size=1.2,
|
155 |
+
color=color_strings, # Use the list of RGB strings for the marker colors
|
156 |
+
)
|
157 |
+
)
|
158 |
+
],
|
159 |
+
layout=dict(
|
160 |
+
scene=dict(
|
161 |
+
xaxis=dict(visible=False),
|
162 |
+
yaxis=dict(visible=False),
|
163 |
+
zaxis=dict(visible=False)
|
164 |
+
),
|
165 |
+
paper_bgcolor='rgb(255,255,255)' # Set the background color to dark gray 50, 50, 50
|
166 |
+
),
|
167 |
+
)
|
168 |
+
|
169 |
+
points = np.concatenate((points, colors), axis=1)
|
170 |
+
if 8192 < points.shape[0]:
|
171 |
+
points = farthest_point_sample(points, 8192)
|
172 |
+
point_clouds = pc_norm(points)
|
173 |
+
point_clouds = torch.from_numpy(point_clouds).unsqueeze_(0).to(torch.float32).cuda()
|
174 |
+
|
175 |
+
answer_time = 0
|
176 |
+
conv.reset()
|
177 |
+
|
178 |
+
outputs = "<span style='color: red;'>[System] New Point Cloud</span>"
|
179 |
+
chatbot = chatbot + [[None, outputs]]
|
180 |
+
|
181 |
+
return fig, list(objects.values())[0] if objects is not None else None, chatbot, answer_time, point_clouds
|
182 |
+
|
183 |
+
def answer_generate(history, answer_time, point_clouds, conv):
|
184 |
+
if point_clouds is None:
|
185 |
+
outputs = "<span style='color: red;'>[System] Please input point cloud! </span>"
|
186 |
+
history[-1][1] = outputs
|
187 |
+
yield history
|
188 |
+
else:
|
189 |
+
print(f"Answer Time: {answer_time}")
|
190 |
+
logging.warning(f"Answer Time: {answer_time}")
|
191 |
+
input_text = history[-1][0]
|
192 |
+
qs = input_text
|
193 |
+
|
194 |
+
if answer_time == 0:
|
195 |
+
if mm_use_point_start_end:
|
196 |
+
qs = default_point_start_token + default_point_patch_token * point_token_len + default_point_end_token + '\n' + qs
|
197 |
+
else:
|
198 |
+
qs = default_point_patch_token * point_token_len + '\n' + qs
|
199 |
+
|
200 |
+
# Append the new message to the conversation history
|
201 |
+
conv.append_message(conv.roles[0], qs)
|
202 |
+
conv.append_message(conv.roles[1], None)
|
203 |
+
prompt = conv.get_prompt()
|
204 |
+
print("#" * 80)
|
205 |
+
print(f'{prompt.replace("<point_patch>" * point_token_len, f"<point_patch> * {point_token_len}")}') # for concise printing
|
206 |
+
print("#" * 80)
|
207 |
+
|
208 |
+
logging.warning("#" * 80)
|
209 |
+
logging.warning(f'{prompt.replace("<point_patch>" * point_token_len, f"<point_patch> * {point_token_len}")}') # for concise printing
|
210 |
+
logging.warning("#" * 80)
|
211 |
+
inputs = tokenizer([prompt])
|
212 |
+
|
213 |
+
input_ids = torch.as_tensor(inputs.input_ids).cuda()
|
214 |
+
|
215 |
+
stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
|
216 |
+
stop_str = keywords[0]
|
217 |
+
|
218 |
+
try:
|
219 |
+
if input_ids.shape[1] >= 2047:
|
220 |
+
raise ValueError("Current context length exceeds the maximum context length (2048) of the model.")
|
221 |
+
with torch.inference_mode():
|
222 |
+
output_ids = model.generate(
|
223 |
+
input_ids,
|
224 |
+
point_clouds=point_clouds,
|
225 |
+
do_sample=True,
|
226 |
+
temperature=1.0,
|
227 |
+
top_k=50,
|
228 |
+
max_length=2048,
|
229 |
+
top_p=0.95,
|
230 |
+
stopping_criteria=[stopping_criteria])
|
231 |
+
|
232 |
+
input_token_len = input_ids.shape[1]
|
233 |
+
n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
|
234 |
+
if n_diff_input_output > 0:
|
235 |
+
print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
|
236 |
+
logging.warning(f'{n_diff_input_output} output_ids are not the same as the input_ids')
|
237 |
+
outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
|
238 |
+
outputs = outputs.strip()
|
239 |
+
if outputs.endswith(stop_str):
|
240 |
+
outputs = outputs[:-len(stop_str)]
|
241 |
+
outputs = outputs.strip()
|
242 |
+
|
243 |
+
# Append the model's response to the conversation history
|
244 |
+
conv.pop_last_none_message()
|
245 |
+
conv.append_message(conv.roles[1], outputs)
|
246 |
+
print(f'{conv.roles[1]}: {outputs}\n')
|
247 |
+
logging.warning(f'{conv.roles[1]}: {outputs}\n')
|
248 |
+
answer_time += 1
|
249 |
+
history[-1][1] = ""
|
250 |
+
for character in outputs:
|
251 |
+
history[-1][1] += character
|
252 |
+
yield history
|
253 |
+
# error
|
254 |
+
except Exception as e:
|
255 |
+
print(f"[ERROR] {e}")
|
256 |
+
logging.warning(f"[ERROR] {e}")
|
257 |
+
|
258 |
+
if input_ids.shape[1] >= 2047:
|
259 |
+
chatbot_system_message = "Current context length exceeds the maximum context length (2048) of the model. Please press 'Clear' to restart."
|
260 |
+
else:
|
261 |
+
chatbot_system_message = "Sorry. There is something wrong when generating. Please check the your uploaded point cloud or the Objaverse id, and \
|
262 |
+
confirm the point cloud again."
|
263 |
+
print(f"[ChatBot System Message]: {chatbot_system_message}")
|
264 |
+
logging.warning(f"[ChatBot System Message]: {chatbot_system_message}")
|
265 |
+
|
266 |
+
outputs = f"<span style='color: red;'>[System] {chatbot_system_message}</span>" # "You upload a new Points Cloud"
|
267 |
+
history[-1][1] = outputs
|
268 |
+
yield history
|
269 |
+
|
270 |
+
with gr.Blocks() as demo:
|
271 |
+
answer_time = gr.State(value=0)
|
272 |
+
point_clouds = gr.State(value=None)
|
273 |
+
conv_state = gr.State(value=conv.copy())
|
274 |
+
gr.Markdown(
|
275 |
+
"""
|
276 |
+
# PointLLM: Empowering Large Language Models to Understand Point Clouds. 🚀
|
277 |
+
If you think this demo interesting, please consider starring 🌟 our github repo. :)
|
278 |
+
[[Project Page](https://runsenxu.com/projects/PointLLM)] [[Paper](https://arxiv.org/abs/2308.16911)] [[Code](https://github.com/OpenRobotLab/PointLLM)]
|
279 |
+
"""
|
280 |
+
)
|
281 |
+
with gr.Row():
|
282 |
+
with gr.Column():
|
283 |
+
input_choice = gr.Radio(['File', 'Object ID'], value='Object ID', interactive=True, label='Input Method', info="How do you want to load point clouds?")
|
284 |
+
object_id_input = gr.Textbox(visible = True,lines=1, label='Object ID Input')
|
285 |
+
point_cloud_input = gr.File(visible = False, label="Upload Point Cloud File (PLY, NPY)")
|
286 |
+
output = gr.Plot()
|
287 |
+
btn = gr.Button(value="Confirm Point Cloud")
|
288 |
+
model3D = gr.Model3D()
|
289 |
+
with gr.Column():
|
290 |
+
chatbot = gr.Chatbot([], elem_id="chatbot", height=560) # ,color_map=("green", "pink")
|
291 |
+
|
292 |
+
def user(user_message, history):
|
293 |
+
return "", history + [[user_message, None]]
|
294 |
+
|
295 |
+
def clear_conv(history, conv):
|
296 |
+
conv.reset()
|
297 |
+
return None, 0
|
298 |
+
|
299 |
+
with gr.Row():
|
300 |
+
text_input = gr.Textbox(
|
301 |
+
show_label=False,
|
302 |
+
placeholder="Enter text and press enter",
|
303 |
+
container=False,
|
304 |
+
)
|
305 |
+
run_button = gr.Button("Send")
|
306 |
+
|
307 |
+
clear = gr.Button("Clear")
|
308 |
+
text_input.submit(user, [text_input, chatbot], [text_input, chatbot], queue=False).then(answer_generate, [chatbot, answer_time, point_clouds, conv_state], chatbot).then(lambda x : x+1,answer_time, answer_time)
|
309 |
+
clear.click(clear_conv, inputs=[chatbot, conv_state], outputs=[chatbot, answer_time], queue=False)
|
310 |
+
|
311 |
+
btn.click(confirm_point_cloud, inputs=[input_choice, object_id_input, point_cloud_input, chatbot, answer_time, conv_state], outputs=[output, model3D, chatbot, answer_time, point_clouds])
|
312 |
+
|
313 |
+
input_choice.change(change_input_method, input_choice, [point_cloud_input, object_id_input])
|
314 |
+
run_button.click(user, [text_input, chatbot], [text_input, chatbot], queue=False).then(answer_generate, [chatbot, answer_time, point_clouds, conv_state], chatbot).then(lambda x : x+1, answer_time, answer_time)
|
315 |
+
|
316 |
+
gr.Markdown(
|
317 |
+
"""
|
318 |
+
### Usage:
|
319 |
+
1. Upload your point cloud file (ply, npy only) or input the supported [Objaverse object id (uid)](https://drive.google.com/file/d/1gLwA7aHfy1KCrGeXlhICG9rT2387tWY8/view?usp=sharing) (currently 660K objects only, you may try the example object ids below).
|
320 |
+
2. If your point cloud file does not contian colors, manually set the file name contains 'no_color' (e.g., 'xxx_no_color.npy'), and the black color will be assigned.
|
321 |
+
3. If uploading your own point cloud file with color in npy format, the first three dimensions should be xyz, and the next three dimensions should be rgb. The rgb values should range from **0 to 1**.
|
322 |
+
4. Click **Confirm Point Cloud**.
|
323 |
+
5. As we use FPS sampling to downsample the point cloud to 8192 points, it may take a long time to confirm the point cloud if the point cloud has too many points. You may use random sampling to downsample the point cloud before uploading.
|
324 |
+
6. Once '[System] New Point Cloud' appears in the dialogue box, a new conversation with PointLLM is initialized.
|
325 |
+
7. The 'Clear' button will clear the conversation history.
|
326 |
+
""")
|
327 |
+
with gr.Accordion("Example Objaverse object ids in the validation set!", open=False):
|
328 |
+
example_object_ids = [ ["b4bbf2116b1a41a5a3b9d3622b07074c", "0b8da82a3d7a436f9b585436c4b72f56", "650c53d68d374c18886aab91bcf8bb54"],
|
329 |
+
["983fa8b23a084f5dacd157e6c9ceba97", "8fe23dd4bf8542b49c3a574b33e377c3", "83cb2a9e9afb47cd9f45461613796645"],
|
330 |
+
["3d679a3888c548afb8cf889915af7fd2", "7bcf8626eaca40e592ffd0aed08aa30b", "69865c89fc7344be8ed5c1a54dbddc20"],
|
331 |
+
["252f3b3f5cd64698826fc1ab42614677", "e85ebb729b02402bbe3b917e1196f8d3", "97367c4740f64935b7a5e34ae1398035"],
|
332 |
+
["fc8dd5a2fc9f4dd19ad6a64a8a6e89e9", "8257772b0e2f408ba269264855dfea00", "d6a3520486bb474f9b5e72eda8408974"],
|
333 |
+
["3d10918e6a9a4ad395a7280c022ad2b9", "00002bcb84af4a4781174e62619f14e2", "76ba80230d454de996878c2763fe7e5c"]]
|
334 |
+
gr.DataFrame(
|
335 |
+
type="array",
|
336 |
+
headers=["Example Object IDs"] * 3,
|
337 |
+
row_count=6,
|
338 |
+
col_count=3,
|
339 |
+
value=example_object_ids
|
340 |
+
)
|
341 |
+
gr.Markdown(
|
342 |
+
"""
|
343 |
+
#### Terms of use
|
344 |
+
By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
|
345 |
+
"""
|
346 |
+
)
|
347 |
+
demo.queue()
|
348 |
+
demo.launch(server_name="0.0.0.0", server_port=args.port, share=False) # server_port=7832, share=True
|
349 |
+
|
350 |
+
if __name__ == "__main__":
|
351 |
+
# ! To release this demo in public, make sure to start in a place where no important data is stored.
|
352 |
+
# ! Please check 1. the lanuch dir 2. the tmp dir (GRADIO_TEMP_DIR)
|
353 |
+
# ! refer to https://www.gradio.app/guides/sharing-your-app#security-and-file-access
|
354 |
+
parser = argparse.ArgumentParser()
|
355 |
+
parser.add_argument("--model-name", type=str, \
|
356 |
+
default="RunsenXu/PointLLM_7B_v1.2")
|
357 |
+
|
358 |
+
|
359 |
+
parser.add_argument("--data_path", type=str, default="data/objaverse_data", required=False)
|
360 |
+
parser.add_argument("--pointnum", type=int, default=8192)
|
361 |
+
|
362 |
+
parser.add_argument("--log_file", type=str, default="serving_workdirs/serving_log.txt")
|
363 |
+
parser.add_argument("--tmp_dir", type=str, default="serving_workdirs/tmp")
|
364 |
+
|
365 |
+
# For gradio
|
366 |
+
parser.add_argument("--port", type=int, default=7810)
|
367 |
+
|
368 |
+
args = parser.parse_args()
|
369 |
+
|
370 |
+
# * make serving dirs
|
371 |
+
os.makedirs(os.path.dirname(args.log_file), exist_ok=True)
|
372 |
+
os.makedirs(args.tmp_dir, exist_ok=True)
|
373 |
+
|
374 |
+
# * add the current time for log name
|
375 |
+
args.log_file = args.log_file.replace(".txt", f"_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.txt")
|
376 |
+
|
377 |
+
logging.basicConfig(
|
378 |
+
filename=args.log_file,
|
379 |
+
level=logging.WARNING, # * default gradio is info, so use warning
|
380 |
+
format='%(asctime)s - %(message)s',
|
381 |
+
datefmt='%Y-%m-%d %H:%M:%S'
|
382 |
+
)
|
383 |
+
|
384 |
+
logging.warning("-----New Run-----")
|
385 |
+
logging.warning(f"args: {args}")
|
386 |
+
|
387 |
+
print("-----New Run-----")
|
388 |
+
print(f"[INFO] Args: {args}")
|
389 |
+
|
390 |
+
# * set env variable GRADIO_TEMP_DIR to args.tmp_dir
|
391 |
+
os.environ["GRADIO_TEMP_DIR"] = args.tmp_dir
|
392 |
+
|
393 |
+
model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv = init_model(args)
|
394 |
+
start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv)
|