yfdeng committed on
Commit 744eb4e · 1 Parent(s): 5df226f
This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. Anymate/.gitignore +26 -0
  2. Anymate/__init__.py +0 -0
  3. Anymate/args.py +22 -0
  4. Anymate/blender_script.py +747 -0
  5. Anymate/checkpoints/.gitkeep +0 -0
  6. Anymate/configs/.gitkeep +0 -0
  7. Anymate/configs/conn.yaml +40 -0
  8. Anymate/configs/conn_token.yaml +40 -0
  9. Anymate/configs/diffusion.yaml +49 -0
  10. Anymate/configs/diffusion_concat.yaml +46 -0
  11. Anymate/configs/diffusion_cross.yaml +51 -0
  12. Anymate/configs/joints.yaml +40 -0
  13. Anymate/configs/joints_implicit.yaml +40 -0
  14. Anymate/configs/joints_triplane.yaml +40 -0
  15. Anymate/configs/skin.yaml +40 -0
  16. Anymate/configs/skin_multi.yaml +40 -0
  17. Anymate/dataset.py +62 -0
  18. Anymate/get_checkpoints.sh +22 -0
  19. Anymate/get_datasets.sh +12 -0
  20. Anymate/model.py +360 -0
  21. Anymate/models/__init__.py +0 -0
  22. Anymate/models/conn.py +195 -0
  23. Anymate/models/diffusion.py +483 -0
  24. Anymate/models/joint.py +282 -0
  25. Anymate/models/skin.py +309 -0
  26. Anymate/tmp/.gitkeep +0 -0
  27. Anymate/utils/dataset_utils.py +129 -0
  28. Anymate/utils/diffusion_encoder.py +258 -0
  29. Anymate/utils/diffusion_utils.py +314 -0
  30. Anymate/utils/eval_utils.py +225 -0
  31. Anymate/utils/loss_utils.py +56 -0
  32. Anymate/utils/render_utils.py +1169 -0
  33. Anymate/utils/train_utils.py +406 -0
  34. Anymate/utils/ui_utils.py +284 -0
  35. Anymate/utils/ui_utils_bpy.py +134 -0
  36. Anymate/utils/utils.py +77 -0
  37. Anymate/utils/vol_utils.py +135 -0
  38. Render.py +17 -0
  39. ThirdParty/PointLLM/.gitignore +12 -0
  40. ThirdParty/PointLLM/README.md +353 -0
  41. ThirdParty/PointLLM/__init__.py +0 -0
  42. ThirdParty/PointLLM/pointllm/__init__.py +1 -0
  43. ThirdParty/PointLLM/pointllm/conversation.py +375 -0
  44. ThirdParty/PointLLM/pointllm/data/__init__.py +3 -0
  45. ThirdParty/PointLLM/pointllm/data/modelnet.py +147 -0
  46. ThirdParty/PointLLM/pointllm/data/modelnet_config/ModelNet40.yaml +8 -0
  47. ThirdParty/PointLLM/pointllm/data/object_point_dataset.py +250 -0
  48. ThirdParty/PointLLM/pointllm/data/utils.py +236 -0
  49. ThirdParty/PointLLM/pointllm/eval/PointLLM_chat.py +157 -0
  50. ThirdParty/PointLLM/pointllm/eval/chat_gradio.py +394 -0
Anymate/.gitignore ADDED
@@ -0,0 +1,26 @@
+__pycache__
+*.pt
+*.tar
+*.tar
+*.txt
+*.glb*
+*.obj
+*.ckpt
+*.blend
+*.blend1
+test_*
+
+blender-*
+*.json*
+*.glb
+*.gltf
+*.fbx
+*.FBX
+*.dae
+*.obj
+*.mtl
+*.binvox
+*.csv
+*.tga
+*.png
+*.jpg
Anymate/__init__.py ADDED
File without changes
Anymate/args.py ADDED
@@ -0,0 +1,22 @@
+class AnymateArgs:
+    def __init__(self):
+        # self.encoder = "miche"
+        # self.decoder = "transformer_latent"
+        # self.dataset = "train"
+        # self.run_name = "miche-transformer_latent-train-8gpu-finetune"
+        self.checkpoint_joint = "Anymate/checkpoints/joint/bert-transformer_latent-train-8gpu-finetune.pth.tar"
+        self.checkpoint_conn = "Anymate/checkpoints/conn/bert-attendjoints_con_combine-train-8gpu-finetune.pth.tar"
+        self.checkpoint_skin = "Anymate/checkpoints/skin/bert-attendjoints_combine-train-8gpu-finetune.pth.tar"
+
+        self.device = "cuda"
+        self.num_joints = 96
+
+
+class UIArgs:
+    def __init__(self):
+        self.checkpoint_joint = "Anymate/checkpoints/joint/bert-transformer_latent-train-8gpu-finetune.pth.tar"
+        self.checkpoint_conn = "Anymate/checkpoints/conn/bert-attendjoints_con_combine-train-8gpu-finetune.pth.tar"
+        self.checkpoint_skin = "Anymate/checkpoints/skin/bert-attendjoints_combine-train-8gpu-finetune.pth.tar"
+
+ui_args = UIArgs()
+anymate_args = AnymateArgs()
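Note: these module-level instances are plain attribute containers, so downstream code can read the default checkpoint paths directly. A minimal usage sketch (assuming the repository root is on PYTHONPATH):

# Sketch: read the defaults defined in Anymate/args.py above.
from Anymate.args import anymate_args, ui_args

print(anymate_args.device)        # "cuda"
print(anymate_args.num_joints)    # 96
print(ui_args.checkpoint_joint)   # path to the joint-prediction checkpoint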
Anymate/blender_script.py ADDED
@@ -0,0 +1,747 @@
+import bpy
+import mathutils
+from mathutils import Vector, Matrix
+
+import os
+import sys
+import random
+import numpy as np
+import json
+import argparse
+
+
+IMPORT_FUNCTIONS = {
+    "obj": bpy.ops.wm.obj_import,
+    "glb": bpy.ops.import_scene.gltf,
+    "gltf": bpy.ops.import_scene.gltf,
+    "usd": bpy.ops.import_scene.usd,
+    "fbx": bpy.ops.import_scene.fbx,
+    "stl": bpy.ops.import_mesh.stl,
+    "usda": bpy.ops.import_scene.usda,
+    "dae": bpy.ops.wm.collada_import,
+    "ply": bpy.ops.import_mesh.ply,
+    "abc": bpy.ops.wm.alembic_import,
+    "blend": bpy.ops.wm.append,
+}
+
+def load_object(object_path: str) -> None:
+    """Loads a model with a supported file extension into the scene.
+
+    Args:
+        object_path (str): Path to the model file.
+
+    Raises:
+        ValueError: If the file extension is not supported.
+
+    Returns:
+        None
+    """
+    file_extension = object_path.split(".")[-1].lower()
+    if file_extension is None:
+        raise ValueError(f"Unsupported file type: {object_path}")
+
+    # load from existing import functions
+    import_function = IMPORT_FUNCTIONS[file_extension]
+
+    if file_extension == "blend":
+        import_function(directory=object_path, link=False)
+    elif file_extension in {"glb", "gltf"}:
+        import_function(filepath=object_path, merge_vertices=True)
+    else:
+        import_function(filepath=object_path)
+
+####################### save json ################################
+def save_json(output_path, mesh_obj, armature_obj, extra=None, arm_name=False):
+    # makedirs output_path
+    os.makedirs(output_path, exist_ok=True)
+
+    # start retrieving the information of mesh, skinning and rigging
+
+    # 1. retrieve the information of rigging, save the world matrix of the armature object
+    total_armature_info = {}
+    for obj in armature_obj:
+        # depsgraph = bpy.context.evaluated_depsgraph_get()
+        # obj = obj.evaluated_get(depsgraph)
+        armature_info = {}
+        armature_info["world_matrix"] = [list(row) for row in obj.matrix_world.copy()]
+        translation = obj.matrix_world.translation
+        for bone in obj.pose.bones:
+            bone_info = {}
+            bone_info["head_local"] = list(bone.head.copy())
+            bone_info["head_world"] = list((obj.matrix_world.to_3x3() @ bone.head+translation).copy())
+            # bone_info["matrix_local"] = [list(row) for row in bone.matrix_local.copy()]
+            bone_info["tail_local"] = list(bone.tail.copy())
+            bone_info["tail_world"] = list((obj.matrix_world.to_3x3() @ bone.tail+translation).copy())
+
+            if bone.parent:
+                bone_info["parent"] = bone.parent.name.replace(" ", "_")
+                if arm_name:
+                    bone_info["parent"] = obj.name + "--" + bone_info["parent"]
+            else:
+                bone_info["parent"] = None
+            bone_info["children"] = []
+            if bone.children:
+                for child in bone.children:
+                    if arm_name:
+                        bone_info["children"].append(obj.name + "--" + child.name.replace(" ", "_"))
+                    else:
+                        bone_info["children"].append(child.name.replace(" ", "_"))
+            bone_name = bone.name.replace(" ", "_")
+            if arm_name:
+                bone_name = obj.name + "--" + bone_name
+            armature_info[bone_name] = bone_info
+        obj_name = obj.name.replace(" ", "_")
+        total_armature_info[obj.name] = armature_info
+
+
+    # 2. retrieve the information of skinning
+    total_skinning_info = {}
+    for obj in mesh_obj:
+        vertex_groups = obj.vertex_groups
+        # if not vertex_groups:
+        #     continue
+        # for group in vertex_groups:
+        skinning_info = {}
+        skinning_info["world_matrix"] = [list(row) for row in obj.matrix_world.copy()]
+        weight_info = []
+        for vertex in obj.data.vertices:
+            vertex_info = {}
+            for group in vertex.groups:
+                name = vertex_groups[group.group].name
+                name = name.replace(" ", "_")
+                if arm_name:
+                    arm_modifier = [modifier for modifier in obj.modifiers if modifier.type == 'ARMATURE']
+                    assert(len(arm_modifier) == 1)
+                    name = arm_modifier[0].object.name + "--" + name
+                weight = group.weight
+                vertex_info[name] = weight
+            weight_info.append(vertex_info)
+        skinning_info["weight"] = weight_info
+        obj_name = obj.name.replace(" ", "_")
+        total_skinning_info[obj_name] = skinning_info
+
+
+    rigging_file_path = os.path.join(output_path, "rigging.json")
+    if extra:
+        rigging_file_path = rigging_file_path.replace("rigging.json", f'rigging_{extra}.json')
+    with open(rigging_file_path, "w") as f:
+        json.dump(total_armature_info, f, indent=2)
+
+    skining_file_path = os.path.join(output_path, "skining.json")
+    if extra:
+        skining_file_path = skining_file_path.replace("skining.json", f'skining_{extra}.json')
+    with open(skining_file_path, "w") as f:
+        json.dump(total_skinning_info, f, indent=2)
+
+
+    return rigging_file_path
+
+
+def apply_skinning_weights(json_file):
+
+    with open(json_file, "r") as f:
+        skinning_data = json.load(f)
+
+    armature_obj = bpy.data.objects.get("Armature")
+    if not armature_obj:
+        print("Error: Armature object 'Armature' not found.")
+        return
+
+    # parent all mesh objects to the armature object
+    count = 0
+    for obj in bpy.context.scene.objects:
+        if obj.type == 'MESH':
+            obj.parent = armature_obj
+            count += 1
+
+    print("total mesh count:", count)
+
+    for obj in bpy.context.scene.objects:
+        vertex_index = 0
+        if obj.type == 'MESH':
+            mesh_name = obj.name
+            if mesh_name in skinning_data:
+                skinning_info = skinning_data[mesh_name]
+                if "weight" in skinning_info:
+                    print("Applying skinning data for mesh:", mesh_name)
+                    vertex_index = 0
+                    for vertex_weight in skinning_info["weight"]:
+                        for bone_name, weight_value in vertex_weight.items():
+                            vertex_group = obj.vertex_groups.get(bone_name)
+                            if vertex_group is None:
+                                vertex_group = obj.vertex_groups.new(name=bone_name)
+                                print("Vertex group created:", bone_name)
+                            vertex_group.add([vertex_index], weight_value, 'REPLACE')
+                        vertex_index += 1
+            else:
+                print("No skinning data found for mesh:", mesh_name)
+    for obj in bpy.context.scene.objects:
+        if obj.type == 'MESH':
+            modifier = obj.modifiers.new(name="Armature", type='ARMATURE')
+            modifier.object = armature_obj
+            modifier.use_vertex_groups = True
+            print("Armature modifier added to mesh:", obj.name)
+
+def reload_rigging(rigging_file_path):
+    with open(rigging_file_path, "r") as f:
+        total_armature_info = json.load(f)
+
+    bpy.ops.object.armature_add()
+    armature_obj = bpy.context.object
+    armature_obj.name = "Armature"
+
+    bpy.ops.object.mode_set(mode='EDIT')
+    bpy.ops.armature.select_all(action='SELECT')
+    bpy.ops.armature.delete()
+    bpy.ops.object.mode_set(mode='OBJECT')
+    bpy.ops.object.mode_set(mode='EDIT')
+
+    world_matrix = mathutils.Matrix([[1, 0, 0, 0],
+                                     [0, 1, 0, 0],
+                                     [0, 0, 1, 0],
+                                     [0, 0, 0, 1]])
+    armature_obj.matrix_world = world_matrix
+
+    for armature_name, armature_info in total_armature_info.items():
+        for bone_name, bone_info in armature_info.items():
+            if bone_name == "world_matrix":
+                continue
+            bone = armature_obj.data.edit_bones.new(bone_name)
+            bone.head = bone_info["head_world"]
+            bone.tail = bone_info["tail_world"]
+
+        for bone_name, bone_info in armature_info.items():
+            if bone_name == "world_matrix":
+                continue
+            bone = armature_obj.data.edit_bones[bone_name]
+            parent_name = bone_info["parent"]
+            if parent_name:
+                parent_bone = armature_obj.data.edit_bones[parent_name]
+                bone.parent = parent_bone
+    edit_len = len(armature_obj.data.edit_bones.keys())
+    bpy.ops.object.mode_set(mode='OBJECT')
+    bone_len = len(armature_obj.data.bones.keys())
+    assert edit_len == bone_len, "bone number not match!" + str(edit_len) + " " + str(bone_len)
+    bpy.ops.object.select_all(action='DESELECT')
+    armature_obj.select_set(True)
+    bpy.context.view_layer.objects.active = armature_obj
+    print("Rigging information has been reloaded!")
+
+############################# reload json ################################
+def reload_json(folder_path, version=0, export=None):
+    bpy.ops.wm.read_homefile(use_empty=True)
+    if version == 0:
+        obj_path = os.path.join(folder_path, "object.obj")
+        skinning_file_path = os.path.join(folder_path, "skining.json")
+        rigging_file_path = os.path.join(folder_path, "rigging.json")
+    elif version == 1:
+        obj_path = os.path.join(folder_path, "join.obj")
+        skinning_file_path = os.path.join(folder_path, "skining_norig.json")
+        rigging_file_path = os.path.join(folder_path, "rigging_norig.json")
+    elif version == 2:
+        obj_path = os.path.join(folder_path, "join.obj")
+        skinning_file_path = os.path.join(folder_path, "skining_norig2.json")
+        rigging_file_path = os.path.join(folder_path, "rigging_norig2.json")
+    # import_obj(obj_path)
+    load_object(obj_path)
+    reload_rigging(rigging_file_path)
+    apply_skinning_weights(skinning_file_path)
+    if export:
+        bpy.ops.wm.save_as_mainfile(filepath=export)
+    print("Done!")
+
+
+def reset_scene() -> None:
+    """Resets the scene to a clean state.
+
+    Returns:
+        None
+    """
+    # delete everything that isn't part of a camera or a light
+    for obj in bpy.data.objects:
+        if obj.type not in {"CAMERA", "LIGHT"}:
+            bpy.data.objects.remove(obj, do_unlink=True)
+
+    # delete all the materials
+    for material in bpy.data.materials:
+        bpy.data.materials.remove(material, do_unlink=True)
+
+    # delete all the textures
+    for texture in bpy.data.textures:
+        bpy.data.textures.remove(texture, do_unlink=True)
+
+    # delete all the images
+    for image in bpy.data.images:
+        bpy.data.images.remove(image, do_unlink=True)
+
+
+def save_mesh(path, mtl=False, obj_path=None):
+    if mtl:
+        # save the blend file
+        bpy.ops.wm.save_as_mainfile(filepath=obj_path + '/object.blend')
+        # reopen the blend file
+        bpy.ops.wm.open_mainfile(filepath=obj_path + '/object.blend')
+        # unpack all the materials and textures to obj_path
+        bpy.ops.file.unpack_all(method='WRITE_LOCAL')
+    # save to .obj without material
+    bpy.ops.wm.obj_export(filepath=path, export_materials=mtl, export_uv=mtl, export_triangulated_mesh=True)
+
+
+def get_root_obj(obj):
+    if not obj.parent:
+        return obj
+    return get_root_obj(obj.parent)
+
+def normalize(objs):
+    # bpy.ops.object.select_all(action='DESELECT')
+    # # select objs and join them
+    # for obj in objs:
+    #     obj.select_set(True)
+    # bpy.context.view_layer.objects.active = objs[0]
+    # name_join = objs[0].name
+    # bpy.ops.object.join()
+    # obj_join = bpy.context.active_object
+    # print(obj_join.matrix_world)
+    # print(name_join)
+    # assert(name_join == obj_join.name)
+
+    objs_eval = []
+    depsgraph = bpy.context.evaluated_depsgraph_get()
+    for obj in objs:
+        objs_eval.append(obj.evaluated_get(depsgraph))
+
+    vertices = []
+    for obj in objs_eval:
+        for v in obj.data.vertices:
+            vertices.append(obj.matrix_world @ Vector((v.co.x, v.co.y, v.co.z, 1)))
+
+    vertices = np.array(vertices)
+    min_x, min_y, min_z, _ = np.min(vertices, axis=0)
+    max_x, max_y, max_z, _ = np.max(vertices, axis=0)
+
+    # print(min_x, min_y, min_z)
+    # print(max_x, max_y, max_z)
+
+    scale_x = 1 / (max_x - min_x)
+    scale_y = 1 / (max_y - min_y)
+    scale_z = 1 / (max_z - min_z)
+    scale_min = min(scale_x, scale_y, scale_z)
+
+    assert scale_min < 1e6
+
+    translate_x = - (max_x + min_x) / 2 * scale_min
+    translate_y = - (max_y + min_y) / 2 * scale_min
+    translate_z = - min_z * scale_min
+
+    # form transformation matrix
+    trans = Matrix.Translation((translate_x, translate_y, translate_z))
+
+    scale = Matrix.Scale(scale_min, 4, (1, 0, 0)) @ Matrix.Scale(scale_min, 4, (0, 1, 0)) @ Matrix.Scale(scale_min, 4, (0, 0, 1))
+
+    # print(trans, scale)
+
+
+    root = get_root_obj(objs[0])
+    # print(root.name)
+    # print(root.scale)
+    # print(root.location)
+    # print(root.matrix_world)
+    # root.location = mathutils.Vector(root.location) + mathutils.Vector((translate_x, translate_y, translate_z))
+    # root.scale = mathutils.Vector(root.scale) * mathutils.Vector((scale_x, scale_y, scale_z))
+
+    # add the extra transformation to the root object's world matrix
+    root.matrix_world = trans @ scale @ root.matrix_world
+    # print(root.name)
+    # print(root.scale)
+    # print(root.location)
+    # print(root.matrix_world)
+
+    # refresh
+    bpy.context.view_layer.update()
+
+    ######### check if its successful
+    # objs_eval = []
+    # depsgraph = bpy.context.evaluated_depsgraph_get()
+    # for obj in objs:
+    #     objs_eval.append(obj.evaluated_get(depsgraph))
+
+    # vertices = []
+    # for obj in objs_eval:
+    #     for v in obj.data.vertices:
+    #         vertices.append(obj.matrix_world @ Vector((v.co.x, v.co.y, v.co.z, 1)))
+
+    # vertices = np.array(vertices)
+    # min_x, min_y, min_z, _ = np.min(vertices, axis=0)
+    # max_x, max_y, max_z, _ = np.max(vertices, axis=0)
+
+    # print(min_x, min_y, min_z)
+    # print(max_x, max_y, max_z)
+
+def remesh(objs, target=5000):
+    num_v = {}
+    for obj in objs:
+        num_v[obj] = len(obj.data.vertices)
+
+    # sort the num_v dict and make it a dict again
+    num_v_sort = sorted(num_v.items(), key=lambda x: x[1], reverse=True)
+
+    # print(num_v_sort)
+    total_v = sum([num_v[obj] for obj in num_v])
+
+    iters = 0
+    while total_v > target and iters < 20:
+        reduce = []
+        for obj, v in num_v_sort:
+            reduce.append(obj)
+            if sum([num_v[oo] for oo in reduce]) > 0.5 * total_v:
+                break
+        for obj in reduce:
+            # check if have shape key
+            if obj.data.shape_keys is not None:
+                # remove obj from num_v
+                num_v.pop(obj)
+                continue
+
+            ratio = 0.5
+            # apply decimate modifier
+            bpy.context.view_layer.objects.active = obj
+            bpy.ops.object.modifier_add(type='DECIMATE')
+            bpy.context.object.modifiers["Decimate"].ratio = ratio
+            bpy.ops.object.modifier_apply(modifier="Decimate")
+            # update num_v
+            num_v[obj] = len(obj.data.vertices)
+        total_v = sum([num_v[obj] for obj in num_v])
+        num_v_sort = sorted(num_v.items(), key=lambda x: x[1], reverse=True)
+        # print(num_v_sort)
+        iters += 1
+
+
+def get_parents(obj):
+    if not obj.parent:
+        return [obj.name]
+    parents = get_parents(obj.parent)
+    parents.append(obj.name)
+    return parents
+
+def check(objs, arm):
+    # assert('Sketchfab_model' in bpy.data.objects)
+
+    # root_arm = get_root_obj(arm)
+    # for obj in objs:
+    #     if root_arm != get_root_obj(obj):
+    #         print('not same root')
+    #         return -1
+    # return 1
+
+    # action_num = 0
+    # actions = bpy.data.actions
+    # for act in actions:
+    #     action_num += 1
+    #     fcurves = act.fcurves
+    #     data_paths = []
+    #     not_pose = False
+    #     for fcurve in fcurves:
+    #         data_paths.append(fcurve.data_path)
+    #         if not fcurve.data_path.startswith('pose.bones'):
+    #             # print(fcurve.data_path)
+    #             not_pose = True
+    #             # return -1
+    #     if not_pose:
+    #         print('zyhsb')
+    #         print(data_paths)
+    #         return -1
+    # return action_num
+
+    for obj in objs:
+        vertex_groups = obj.vertex_groups
+        # if not vertex_groups:
+        #     continue
+        # for group in vertex_groups:
+        for vertex in obj.data.vertices:
+            vertex_info = {}
+            for group in vertex.groups:
+                name = vertex_groups[group.group].name
+                name = name.replace(" ", "_")
+                if True:
+                    arm_modifier = [modifier for modifier in obj.modifiers if modifier.type == 'ARMATURE']
+                    if len(arm_modifier) != 1:
+                        print('zyhsb', len(arm_modifier))
+                        return -2
+                    # name = arm_modifier[0].object.name + "--" + name
+    return 1
+
+    # for obj in objs:
+    #     if obj.data.shape_keys is not None:
+    #         return 1
+    # # only 942!!!
+    # return 0
+
+
+def delete(objs):
+    # check if the mesh object has skinning weight
+    for obj in objs:
+        vertex_groups = obj.vertex_groups
+        if not vertex_groups:
+            # delete the object
+            bpy.data.objects.remove(obj)
+            # print('delete!!!')
+    meshes = []
+    for obj in bpy.context.scene.objects:
+        if obj.type == "MESH":
+            meshes.append(obj)
+
+    return meshes
+
+
+def merge_mesh(folder_path, export=None, save_join=True):
+    # output_path = os.path.join(folder_path, "rigging_norig.json")
+    # if os.path.exists(output_path):
+    #     print("Already processed folder:", folder_path)
+    #     return
+    bpy.ops.wm.read_homefile(use_empty=True)
+    try:
+        reload_json(folder_path)
+    except:
+        print("Error in reloading json file")
+        # remove the folder
+        os.system(f"rm -r {folder_path}")
+        return None, None
+
+    bpy.ops.object.select_all(action='DESELECT')
+    if export:
+        bpy.ops.wm.save_as_mainfile(filepath='reload_' + export)
+
+    meshes = []
+    for obj in bpy.context.scene.objects:
+        if obj.type == "MESH":
+            bpy.context.view_layer.objects.active = obj
+            obj.select_set(True)
+            meshes.append(obj)
+    print("meshes length", len(meshes))
+
+    bpy.ops.object.join()
+    if export:
+        bpy.ops.wm.save_as_mainfile(filepath='join_' + export)
+
+    meshes = []
+    for obj in bpy.context.scene.objects:
+        if obj.type == "MESH":
+            meshes.append(obj)
+    if len(meshes) != 1:
+        bpy.ops.wm.save_as_mainfile(filepath='join_f.blend')
+    assert len(meshes) == 1
+    # remesh(meshes[0])
+
+
+    if save_join:
+        obj_path = os.path.join(folder_path, "object.obj")
+        bpy.ops.wm.obj_export(filepath=obj_path, export_materials=False, export_uv=False, export_triangulated_mesh=True)
+        # mesh = trimesh.load(glb_file_path)
+        # mesh.export(obj_path, file_type='obj')
+
+
+    # save to json file
+    total_armature_count = 0
+    armature_obj = []
+    mesh_obj = []
+    for obj in bpy.context.scene.objects:
+        if obj.type == "ARMATURE":
+            total_armature_count += 1
+            armature_obj.append(obj)
+        if obj.type == "MESH":
+            mesh_obj.append(obj)
+    if total_armature_count == 0:
+        print("No rigging information for the file:", folder_path + "\n")
+        return None, None
+
+
+    ######### delete bones that are not in the vertex group
+    vertex_group_name = [group.name for group in mesh_obj[0].vertex_groups]
+    bpy.context.view_layer.objects.active = armature_obj[0]
+    bpy.ops.object.mode_set(mode='EDIT')
+    edit_bones = armature_obj[0].data.edit_bones
+    bone_delete = set([bone.name for bone in edit_bones]) - set(vertex_group_name)
+    print(f"Deleting {len(bone_delete)} bones")
+    for bone in bone_delete:
+        # if the bone is root, then do not delete it
+        if edit_bones[bone].parent == None:
+            # return len([1 for child in edit_bones[bone].children if child.name in bone_delete])
+            num_children = len(edit_bones[bone].children)
+            if num_children <= 1:
+                edit_bones.remove(edit_bones[bone])
+                continue
+            if num_children > 1:
+                center = mathutils.Vector((0, 0, 0))
+                for child in edit_bones[bone].children:
+                    center += child.head
+                center /= num_children
+                min_dist = 1e9
+                for child in edit_bones[bone].children:
+                    dist = (child.head - center).length
+                    if dist < min_dist:
+                        min_dist = dist
+                        min_child = child
+                for child in edit_bones[bone].children:
+                    if child != min_child:
+                        child.parent = min_child
+                edit_bones.remove(edit_bones[bone])
+                continue
+            continue
+        # assign bone's children to bone's parent
+        bone_obj = edit_bones[bone]
+        for child in bone_obj.children:
+            child.parent = bone_obj.parent
+
+        edit_bones.remove(edit_bones[bone])
+    bpy.ops.object.mode_set(mode='OBJECT')
+
+    if export:
+        bpy.ops.wm.save_as_mainfile(filepath='delete_' + export)
+
+    mesh_obj = []
+    armature_obj = []
+    for obj in bpy.context.scene.objects:
+        if obj.type == "MESH":
+            mesh_obj.append(obj)
+        if obj.type == "ARMATURE":
+            armature_obj.append(obj)
+    assert len(mesh_obj) == 1
+    assert len(armature_obj) == 1
+
+    return mesh_obj, armature_obj
+
+
+def process(file_path, obj_path=None, stamp=None, tex=False):
+    # check if obj_path exists
+    # if os.path.exists(obj_path + '/object.obj'):
+    #     print('object.obj exists')
+    #     return True
+    reset_scene()
+    load_object(file_path)
+    # bpy.ops.import_scene.gltf(filepath=glb_file_path)
+
+    # delete hierarchy collections['glTF_not_exported']
+    if 'glTF_not_exported' in bpy.data.collections:
+        print('DELETE glTF_not_exported')
+        bpy.data.collections.remove(bpy.data.collections['glTF_not_exported'])
+
+    if stamp is not None:
+        # Set the current frame to the stamp value
+        bpy.context.scene.frame_set(stamp)
+        print(f'Set the current frame to {stamp}')
+
+        # Ensure all objects are updated to this frame
+        bpy.context.view_layer.update()
+
+    mesh_obj = []
+    armature_obj = []
+    for obj in bpy.context.scene.objects:
+        if obj.type == "ARMATURE":
+            # if len(armature_obj) > 0:
+            #     print(file_path, 'has more than 1 armature')
+            #     return -2
+            armature_obj.append(obj)
+            # obj.show_in_front = True
+            armature_obj[-1].data.pose_position = 'POSE'
+        if obj.type == "MESH":
+            mesh_obj.append(obj)
+            # if obj.data.shape_keys is not None:
+            #     return False
+
+    # mesh_obj = delete(mesh_obj)
+    # if len(mesh_obj) == 0:
+    #     # print('zyhsb -1', file_path, obj_path)
+    #     return -1
+    # return check(mesh_obj, armature_obj)
+
+
+    # total_vertices = np.array([len(obj.data.vertices) for obj in mesh_obj]).sum()
+    # if total_vertices < 1000: return
+    # if total_vertices > 10000: remesh(mesh_obj)
+
+
+    # bpy.ops.object.select_all(action='DESELECT')
+    # armature_obj.select_set(True)
+    # execute(bpy.context)
+
+
+    # normalize(mesh_obj)
+
+
+    mesh_obj = delete(mesh_obj)
+    if len(mesh_obj) == 0:
+        # print('zyhsb -1', file_path, obj_path)
+        return -1
+
+
+    save_json(obj_path, mesh_obj, armature_obj, arm_name=True)
+
+
+    if not tex:
+        save_mesh(obj_path + '/object.obj')
+    else:
+        save_mesh(obj_path + '/object.obj', mtl=True, obj_path=obj_path)
+
+
+    mesh_obj, armature_obj = merge_mesh(obj_path)
+    if mesh_obj is None or armature_obj is None:
+        # print('zyhsb -2', file_path, obj_path)
+        return -2
+
+
+    try:
+        normalize(mesh_obj)
+    except:
+        os.system(f"rm -r {obj_path}")
+        # print('zyhsb -3', file_path, obj_path)
+        return -3
+
+
+    save_json(obj_path, mesh_obj, armature_obj)
+
+    if not tex:
+        save_mesh(obj_path + '/object.obj')
+    else:
+        save_mesh(obj_path + '/object.obj', mtl=True, obj_path=obj_path)
+
+
+    return 1
+
+
+if __name__ == '__main__':
+
+    parser = argparse.ArgumentParser()
+    parser.add_argument(
+        "--object_path",
+        type=str,
+        required=True,
+        help="Path to the object file",
+    )
+    parser.add_argument(
+        "--output_dir",
+        type=str,
+        required=True,
+        help="Path to the directory where the rendered images and metadata will be saved.",
+    )
+    parser.add_argument(
+        "--stamp",
+        type=int,
+        required=False,
+        help="Stamp to be used for the rendering.",
+    )
+    parser.add_argument(
+        "--tex",
+        type=bool,
+        required=False,
+        help="Save the texture.",
+    )
+    argv = sys.argv[sys.argv.index("--") + 1 :]
+    args = parser.parse_args(argv)
+
+    os.makedirs(args.output_dir, exist_ok=True)
+    stamp = args.stamp if args.stamp else None
+    print(f'Stamp: {stamp}')
+    result = process(args.object_path, obj_path=args.output_dir, stamp=stamp, tex=args.tex)
+    # import numpy as np
+    # os.makedirs(args.output_dir, exist_ok=True)  # the directory may be removed
+    # np.save(args.output_dir + '/result.npy', np.array(result))
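Note: the `__main__` block parses only the arguments after Blender's `--` separator, so this script is meant to run inside a headless Blender process. A hedged sketch of driving it from Python; the `blender` executable name and the sample input/output paths are assumptions, not part of this commit:

# Sketch: invoke blender_script.py through headless Blender and collect the
# exported object.obj / rigging.json / skining.json files in the output dir.
import subprocess

cmd = [
    "blender", "--background", "--python", "Anymate/blender_script.py", "--",
    "--object_path", "assets/example.glb",   # hypothetical input file
    "--output_dir", "Anymate/tmp/example",   # hypothetical output directory
]
subprocess.run(cmd, check=True)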
Anymate/checkpoints/.gitkeep ADDED
File without changes
Anymate/configs/.gitkeep ADDED
File without changes
Anymate/configs/conn.yaml ADDED
@@ -0,0 +1,40 @@
+args:
+  aggr: max
+  checkpoint: Anymate/checkpoints
+  device: cuda
+  epochs: 200
+  finetune: true
+  gamma: 0.2
+  input_normal: false
+  logdir: Anymate/logs
+  loss: ce
+  mode: conn
+  resume: ''
+  root: Anymate/data
+  schedule: []
+  start_epoch: 0
+  test_batch: 1
+  testset: Anymate_test
+  train_batch: 16
+  trainset: Anymate_train
+  test_freq: 10
+
+optimizer:
+  weight_decay: 1.0e-05
+  lr: 0.0001
+
+model:
+  decoder: attendjoints_con_combine
+  encoder: bert
+  config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+  ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+  load_encoder: ''
+  num_joints: 96
+  out_channels: 3
+  width: 768
+  heads: 12
+  init_scale: 0.25
+  flash: False
+  use_checkpoint: False
+  qkv_bias: False
+  separate: False
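Note: all configs in this commit share the same three top-level groups (args, optimizer, model), so they can be read with a generic YAML loader. A minimal sketch, assuming PyYAML is available; the project's own training code may use a different loader:

# Sketch: load a config and inspect the three groups defined above.
import yaml

with open("Anymate/configs/conn.yaml") as f:
    cfg = yaml.safe_load(f)

print(cfg["args"]["mode"], cfg["args"]["loss"])          # conn ce
print(cfg["optimizer"]["lr"])                            # 0.0001
print(cfg["model"]["encoder"], cfg["model"]["decoder"])  # bert attendjoints_con_combine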
Anymate/configs/conn_token.yaml ADDED
@@ -0,0 +1,40 @@
+args:
+  aggr: max
+  checkpoint: Anymate/checkpoints
+  device: cuda
+  epochs: 200
+  finetune: true
+  gamma: 0.2
+  input_normal: false
+  logdir: Anymate/logs
+  loss: ce
+  mode: conn
+  resume: ''
+  root: Anymate/data
+  schedule: []
+  start_epoch: 0
+  test_batch: 1
+  testset: Anymate_test
+  train_batch: 16
+  trainset: Anymate_train
+  test_freq: 10
+
+optimizer:
+  weight_decay: 1.0e-05
+  lr: 0.0001
+
+model:
+  decoder: attendjoints_con_combine
+  encoder: bert
+  config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+  ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+  load_encoder: ''
+  num_joints: 96
+  out_channels: 3
+  width: 768
+  heads: 12
+  init_scale: 0.25
+  flash: False
+  use_checkpoint: False
+  qkv_bias: False
+  separate: False
Anymate/configs/diffusion.yaml ADDED
@@ -0,0 +1,49 @@
+args:
+  aggr: max
+  checkpoint: Anymate/checkpoints
+  device: cuda
+  epochs: 4000
+  finetune: true
+  gamma: 0.2
+  input_normal: false
+  logdir: Anymate/logs
+  loss: chamfer
+  mode: diffusion
+  resume: ''
+  root: Anymate/data
+  schedule: []
+  start_epoch: 0
+  test_batch: 1
+  testset: Anymate_test
+  train_batch: 16
+  trainset: Anymate_train
+  test_freq: 50
+  num_train_step: 100
+  num_training_points: 128
+  seed: 42
+
+optimizer:
+  weight_decay: 1.0e-05
+  lr: 0.0001
+
+model:
+  encoder: transformer
+  decoder: Cross_Attention_Diffusion
+  config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+  ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+  input_channels: 3
+  output_channels: 3
+  num_z: 16
+  num_x: 128
+  z_dim: 768
+  x_dim: 512
+  num_blocks: 4
+  num_compute_layers: 4
+  num_heads: 8
+  mlp_ratio: 4.0
+  qkv_bias: true
+  drop: 0.0
+  attn_drop: 0.0
+  drop_path: 0.0
+  num_latents: 16
+  use_projection: true
Anymate/configs/diffusion_concat.yaml ADDED
@@ -0,0 +1,46 @@
+args:
+  aggr: max
+  checkpoint: Anymate/checkpoints
+  device: cuda
+  epochs: 4000
+  finetune: true
+  gamma: 0.2
+  input_normal: false
+  logdir: Anymate/logs
+  loss: chamfer
+  mode: diffusion
+  resume: ''
+  root: Anymate/data
+  schedule: []
+  start_epoch: 0
+  test_batch: 1
+  testset: Anymate_test
+  train_batch: 16
+  trainset: Anymate_train
+  test_freq: 1000
+  num_train_step: 100
+  num_training_points: 128
+  seed: 42
+
+optimizer:
+  weight_decay: 1.0e-05
+  lr: 0.0001
+
+model:
+  encoder: bert
+  decoder: Pointe_Diffusion
+  config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+  ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+  input_channels: 3
+  output_channels: 3
+  n_ctx: 128
+  width: 768
+  layers: 12
+  heads: 8
+  init_scale: 0.25
+  time_token_cond: true
+  cond_drop_prob: 0.1
+  use_projection: true
+
+
+
Anymate/configs/diffusion_cross.yaml ADDED
@@ -0,0 +1,51 @@
+args:
+  aggr: max
+  checkpoint: Anymate/checkpoints
+  device: cuda
+  epochs: 4000
+  finetune: true
+  gamma: 0.2
+  input_normal: false
+  logdir: Anymate/logs
+  loss: chamfer
+  mode: diffusion
+  resume: ''
+  root: Anymate/data
+  schedule: []
+  start_epoch: 0
+  test_batch: 1
+  testset: Anymate_test
+  train_batch: 32
+  trainset: Anymate_train
+  test_freq: 1000
+  num_train_step: 100
+  num_training_points: 128
+  seed: 42
+
+optimizer:
+  weight_decay: 1.0e-05
+  lr: 0.0001
+
+model:
+  encoder: miche
+  decoder: Cross_Attention_Diffusion
+  config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+  ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+  input_channels: 3
+  output_channels: 3
+  num_z: 16
+  num_x: 128
+  z_dim: 768
+  x_dim: 512
+  num_blocks: 4
+  num_compute_layers: 4
+  num_heads: 8
+  mlp_ratio: 4.0
+  qkv_bias: true
+  drop: 0.0
+  attn_drop: 0.0
+  drop_path: 0.0
+  num_latents: 16
+  use_projection: true
+
+
Anymate/configs/joints.yaml ADDED
@@ -0,0 +1,40 @@
+args:
+  aggr: max
+  checkpoint: Anymate/checkpoints
+  device: cuda
+  epochs: 200
+  finetune: true
+  gamma: 0.2
+  input_normal: false
+  logdir: Anymate/logs
+  loss: chamfer
+  mode: joints
+  resume: ''
+  root: Anymate/data
+  schedule: []
+  start_epoch: 0
+  test_batch: 1
+  testset: Anymate_test
+  train_batch: 16
+  trainset: Anymate_train
+  test_freq: 10
+
+optimizer:
+  weight_decay: 1.0e-05
+  lr: 0.0001
+
+model:
+  decoder: transformer_latent
+  encoder: bert
+  config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+  ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+  load_encoder: ''
+  num_joints: 96
+  out_channels: 3
+  width: 768
+  heads: 12
+  init_scale: 0.25
+  flash: False
+  use_checkpoint: False
+  qkv_bias: False
+  separate: False
Anymate/configs/joints_implicit.yaml ADDED
@@ -0,0 +1,40 @@
+args:
+  aggr: max
+  checkpoint: Anymate/checkpoints
+  device: cuda
+  epochs: 200
+  finetune: true
+  gamma: 0.2
+  input_normal: false
+  logdir: Anymate/logs
+  loss: chamfer
+  mode: joints
+  resume: ''
+  root: Anymate/data
+  schedule: []
+  start_epoch: 0
+  test_batch: 1
+  testset: Anymate_test
+  train_batch: 8
+  trainset: Anymate_train
+  test_freq: 10
+
+optimizer:
+  weight_decay: 1.0e-05
+  lr: 0.0001
+
+model:
+  decoder: implicit_transformer
+  encoder: bert
+  config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+  ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+  load_encoder: ''
+  num_joints: 96
+  out_channels: 3
+  width: 768
+  heads: 12
+  init_scale: 0.25
+  flash: False
+  use_checkpoint: False
+  qkv_bias: False
+  separate: False
Anymate/configs/joints_triplane.yaml ADDED
@@ -0,0 +1,40 @@
+args:
+  aggr: max
+  checkpoint: Anymate/checkpoints
+  device: cuda
+  epochs: 200
+  finetune: true
+  gamma: 0.2
+  input_normal: false
+  logdir: Anymate/logs
+  loss: chamfer
+  mode: joints
+  resume: ''
+  root: Anymate/data
+  schedule: []
+  start_epoch: 0
+  test_batch: 1
+  testset: Anymate_test
+  train_batch: 16
+  trainset: Anymate_train
+  test_freq: 10
+
+optimizer:
+  weight_decay: 1.0e-05
+  lr: 0.0001
+
+model:
+  decoder: triplane
+  encoder: bert
+  config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+  ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+  load_encoder: ''
+  num_joints: 96
+  out_channels: 3
+  width: 768
+  heads: 12
+  init_scale: 0.25
+  flash: False
+  use_checkpoint: False
+  qkv_bias: False
+  separate: False
Anymate/configs/skin.yaml ADDED
@@ -0,0 +1,40 @@
+args:
+  aggr: max
+  checkpoint: Anymate/checkpoints
+  device: cuda
+  epochs: 200
+  finetune: true
+  gamma: 0.2
+  input_normal: false
+  logdir: Anymate/logs
+  loss: cos_clamp
+  mode: skin
+  resume: ''
+  root: Anymate/data
+  schedule: []
+  start_epoch: 0
+  test_batch: 1
+  testset: Anymate_test
+  train_batch: 16
+  trainset: Anymate_train
+  test_freq: 10
+
+optimizer:
+  weight_decay: 1.0e-05
+  lr: 0.0001
+
+model:
+  decoder: attendjoints_combine
+  encoder: bert
+  config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+  ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+  load_encoder: ''
+  num_joints: 96
+  out_channels: 3
+  width: 768
+  heads: 12
+  init_scale: 0.25
+  flash: False
+  use_checkpoint: False
+  qkv_bias: False
+  separate: False
Anymate/configs/skin_multi.yaml ADDED
@@ -0,0 +1,40 @@
+args:
+  aggr: max
+  checkpoint: Anymate/checkpoints
+  device: cuda
+  epochs: 200
+  finetune: true
+  gamma: 0.2
+  input_normal: false
+  logdir: Anymate/logs
+  loss: cos_clamp
+  mode: skin
+  resume: ''
+  root: Anymate/data
+  schedule: []
+  start_epoch: 0
+  test_batch: 1
+  testset: Anymate_test
+  train_batch: 4
+  trainset: Anymate_train
+  test_freq: 10
+
+optimizer:
+  weight_decay: 1.0e-05
+  lr: 0.0001
+
+model:
+  decoder: attendjoints_multi
+  encoder: bert
+  config_path: ./ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml
+  ckpt_path: ./ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt
+  load_encoder: ''
+  num_joints: 96
+  out_channels: 3
+  width: 768
+  heads: 12
+  init_scale: 0.25
+  flash: False
+  use_checkpoint: False
+  qkv_bias: False
+  separate: False
Anymate/dataset.py ADDED
@@ -0,0 +1,62 @@
+import torch
+from torch.utils.data import Dataset
+import os
+import numpy as np
+from Anymate.utils.dataset_utils import create_mask, index_to_sparse, index_to_sparse_con
+
+def my_collate(batch):
+    # print(len(batch))
+    data = {}
+    for key in batch[0]:
+        if key=='vox' or key=='name' or key=='joints_num' or key=='skins_index' or key=='skins_weight' or key=='parent_index' or key=='conns' or key=='joints' or key=='bones' or key=='mesh_skins_index' or key=='mesh_skins_weight' or key=='mesh_pc' or key=='mesh_face':
+            data[key] = [sample[key] for sample in batch]
+        elif key=='pc':
+            data['points_cloud'] = torch.stack([sample['pc'] for sample in batch])
+        elif key=='skins':
+            continue
+        elif key=='bones_num':
+            data[key] = torch.tensor([sample['bones_num'] for sample in batch])
+        else:
+            data[key] = torch.stack([sample[key] for sample in batch])
+
+    if 'skins_index' in batch[0]:
+        max_joints = max(data['joints_num'])
+        max_bones = max(data['bones_num'])
+        # max_joints = 64
+        skin_list = [index_to_sparse(data['skins_index'][i].unsqueeze(0), data['skins_weight'][i].unsqueeze(0), [1, 8192, max_bones])[0] for i in range(len(data['skins_index']))]
+        data['skins'] = torch.stack(skin_list, dim=0)
+        data['joints_mask'] = torch.stack([create_mask(sample['joints_num'], max_len=max_joints) for sample in batch])
+        data['bones_mask'] = torch.stack([create_mask(sample['bones_num'], max_len=max_bones) for sample in batch])
+
+    if 'conns' in batch[0]:
+        max_joints = max(data['joints_num'])
+        conn_matrix = torch.zeros(len(data['conns']), 96, max_joints)
+        for i in range(len(data['conns'])):
+            for j in range(data['joints_num'][i]):
+                conn_matrix[i, j, data['conns'][i][j].long()] = 1
+        data['conns'] = conn_matrix
+    if 'joints' in batch[0]:
+        padded_joints_matrix = torch.ones(len(data['name']), 96, 3) * (-3)
+        for i in range(len(data['name'])):
+            padded_joints_matrix[i, :data['joints_num'][i], :] = data['joints'][i]
+        data['joints'] = padded_joints_matrix
+    if 'bones' in batch[0]:
+        padded_bones_matrix = torch.ones(len(data['name']), 64, 6) * (-3)
+        for i in range(len(data['name'])):
+            padded_bones_matrix[i, :data['bones_num'][i], :] = data['bones'][i]
+        data['bones'] = padded_bones_matrix
+    return data
+
+class AnymateDataset(Dataset):
+    def __init__(self, name='Anymate_test', root='Anymate/data'):
+
+        if os.path.exists(os.path.join(root, name) + '.pt'):
+            self.data_list = torch.load(os.path.join(root, name) + '.pt')
+        else:
+            raise ValueError('Dataset not found at path: {}'.format(os.path.join(root, name) + '.pt'))
+
+    def __len__(self):
+        return len(self.data_list)
+
+    def __getitem__(self, idx):
+        return self.data_list[idx]
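Note: AnymateDataset returns raw per-sample dicts, and my_collate is what builds the padded joint/bone/skin tensors, so the two are meant to be used together. A minimal sketch of wiring them into a PyTorch DataLoader; the batch size and worker count are illustrative only:

# Sketch: iterate the test split with the custom collate function defined above.
from torch.utils.data import DataLoader
from Anymate.dataset import AnymateDataset, my_collate

dataset = AnymateDataset(name='Anymate_test', root='Anymate/data')
loader = DataLoader(dataset, batch_size=4, shuffle=False, num_workers=2, collate_fn=my_collate)

for batch in loader:
    print(batch['points_cloud'].shape)  # stacked from each sample's 'pc' entry
    break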
Anymate/get_checkpoints.sh ADDED
@@ -0,0 +1,22 @@
+cd Anymate/checkpoints
+mkdir joint
+cd joint
+
+echo "Downloading joint checkpoints..."
+wget "https://huggingface.co/yfdeng/Anymate/resolve/main/checkpoints/joint/bert-transformer_latent-train-8gpu-finetune.pth.tar?download=true" -O bert-transformer_latent-train-8gpu-finetune.pth.tar
+
+cd ..
+mkdir conn
+cd conn
+
+echo "Downloading conn checkpoints..."
+wget "https://huggingface.co/yfdeng/Anymate/resolve/main/checkpoints/conn/bert-attendjoints_con_combine-train-8gpu-finetune.pth.tar?download=true" -O bert-attendjoints_con_combine-train-8gpu-finetune.pth.tar
+
+cd ..
+mkdir skin
+cd skin
+
+echo "Downloading skin checkpoints..."
+wget "https://huggingface.co/yfdeng/Anymate/resolve/main/checkpoints/skin/bert-attendjoints_combine-train-8gpu-finetune.pth.tar?download=true" -O bert-attendjoints_combine-train-8gpu-finetune.pth.tar
+
+echo "Finished downloading checkpoints!"
Anymate/get_datasets.sh ADDED
@@ -0,0 +1,12 @@
+cd Anymate/data
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_test.pt?download=true" -O Anymate_test.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_0.pt?download=true" -O Anymate_train_0.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_1.pt?download=true" -O Anymate_train_1.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_2.pt?download=true" -O Anymate_train_2.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_3.pt?download=true" -O Anymate_train_3.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_4.pt?download=true" -O Anymate_train_4.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_5.pt?download=true" -O Anymate_train_5.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_6.pt?download=true" -O Anymate_train_6.pt
+wget "https://huggingface.co/datasets/yfdeng/Anymate/resolve/main/Anymate_train_7.pt?download=true" -O Anymate_train_7.pt
+
+echo "Finished downloading datasets!"
Anymate/model.py ADDED
@@ -0,0 +1,360 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from ThirdParty.michelangelo.utils.misc import get_config_from_file, instantiate_from_config
4
+ # from ThirdParty.PointLLM.pointllm.model.pointllm import PointLLMLlamaForCausalLM
5
+ from ThirdParty.michelangelo.models.modules.distributions import DiagonalGaussianDistribution
6
+ from ThirdParty.michelangelo.models.modules.embedder import components_from_spherical_harmonics
7
+ from Anymate.utils.diffusion_encoder import TransformerEncoder
8
+ from Anymate.models.joint import TransformerDecoder, ImplicitTransformerDecoder, TriPlaneDecoder
9
+ from Anymate.models.conn import AttendjointsDecoder_con_combine, AttendjointsDecoder_con_token
10
+ from Anymate.models.skin import AttendjointsDecoder_combine, AttendjointsDecoder_multi
11
+ from Anymate.models.diffusion import Pointe_Diffusion, Cross_Attention_Diffusion
12
+
13
+ class Encoder(nn.Module):
14
+ def __init__(self,
15
+ only_embed = True,
16
+ config_path = './ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml',
17
+ ckpt_path = './ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt',
18
+ num_latents = 257,
19
+ device = 'cuda'):
20
+
21
+ super().__init__()
22
+
23
+ model_config = get_config_from_file(config_path)
24
+ if hasattr(model_config, "model"):
25
+ model_config = model_config.model
26
+
27
+ if ckpt_path is not None:
28
+ model = instantiate_from_config(model_config, ckpt_path=ckpt_path)
29
+ else:
30
+ model = instantiate_from_config(model_config)
31
+ model.model.shape_model.encoder.num_latents = num_latents
32
+ model.model.shape_model.encoder.query = nn.Parameter(torch.randn((num_latents, 768), device=device, dtype=torch.float32) * 0.02)
33
+
34
+ self.shape_projection = model.model.shape_projection
35
+ self.encoder = model.model.shape_model.encoder
36
+ self.normal_embedder = components_from_spherical_harmonics
37
+ old_linear_proj = self.encoder.input_proj
38
+ self.encoder.input_proj = nn.Linear(old_linear_proj.in_features + 25, old_linear_proj.out_features)
39
+ self.encoder.input_proj.weight.data[:, :old_linear_proj.in_features] = old_linear_proj.weight.data[:, :old_linear_proj.in_features].clone()
40
+ self.encoder.input_proj.bias.data = old_linear_proj.bias.data.clone()
41
+ if not only_embed:
42
+ self.embed_dim = model.model.shape_model.embed_dim
43
+ self.pre_kl = model.model.shape_model.pre_kl
44
+ self.post_kl = model.model.shape_model.post_kl
45
+ self.transformer = model.model.shape_model.transformer
46
+
47
+
48
+ def encode_latents(self,
49
+ pc: torch.FloatTensor,
50
+ feats = None):
51
+
52
+ feats_embed = self.normal_embedder(feats)
53
+ feats = torch.cat([feats, feats_embed], dim=-1)
54
+
55
+ x, _ = self.encoder(pc, feats)
56
+
57
+ shape_embed = x[:, 0]
58
+ latents = x[:, 1:]
59
+
60
+ return shape_embed, latents
61
+
62
+
63
+ def encode_shape_embed(self, surface, return_latents: bool = False):
64
+ """
65
+
66
+ Args:
67
+ surface (torch.FloatTensor): [bs, n, 3 + c]
68
+ return_latents (bool):
69
+
70
+ Returns:
71
+ x (torch.FloatTensor): [bs, projection_dim]
72
+ shape_latents (torch.FloatTensor): [bs, m, d]
73
+ """
74
+
75
+ pc = surface[..., 0:3]
76
+ feats = surface[..., 3:]
77
+
78
+ shape_embed, shape_latents = self.encode_latents(pc, feats)
79
+ x = shape_embed @ self.shape_projection
80
+
81
+ if return_latents:
82
+ return x, shape_latents
83
+ else:
84
+ return x
85
+
86
+
87
+ def encode_kl_embed(self, latents: torch.FloatTensor, sample_posterior: bool = True):
88
+ posterior = None
89
+ if self.embed_dim > 0:
90
+ moments = self.pre_kl(latents)
91
+ posterior = DiagonalGaussianDistribution(moments, feat_dim=-1)
92
+
93
+ if sample_posterior:
94
+ kl_embed = posterior.sample()
95
+ else:
96
+ kl_embed = posterior.mode()
97
+ else:
98
+ kl_embed = latents
99
+
100
+ return kl_embed, posterior
101
+
102
+
103
+ def decode(self, latents: torch.FloatTensor):
104
+ latents = self.post_kl(latents)
105
+ return self.transformer(latents)
106
+
107
+
108
+ class EncoderDecoder(nn.Module):
109
+ def __init__(self,
110
+ decoder = 'mlp',
111
+ encoder = 'miche',
112
+ config_path = './ThirdParty/michelangelo/configs/aligned_shape_latents/shapevae-256.yaml',
113
+ ckpt_path = './ThirdParty/michelangelo/checkpoints/aligned_shape_latents/shapevae-256.ckpt',
114
+ load_encoder = '',
115
+ num_joints = 96,
116
+ out_channels = 3,
117
+ width = 768,
118
+ device = 'cuda',
119
+ dtype = torch.float32,
120
+ heads = 12,
121
+ init_scale: float = 0.25,
122
+ flash = False,
123
+ use_checkpoint = False,
124
+ qkv_bias = False,
125
+ separate = False,
126
+ **kwargs):
127
+
128
+ super().__init__()
129
+ self.decoder_name = decoder
130
+ self.encoder_name = encoder
131
+ self.dtype = dtype
132
+ self.load_encoder = load_encoder
133
+
134
+ if decoder == 'transformer_latent':
135
+ self.only_embed = False
136
+ self.return_latents = True
137
+ self.decoder = TransformerDecoder(
138
+ num_latents = num_joints,
139
+ out_channels = out_channels,
140
+ width = width,
141
+ device = device,
142
+ dtype = dtype,
143
+ heads = heads,
144
+ init_scale = init_scale,
145
+ flash = flash,
146
+ use_checkpoint = use_checkpoint,
147
+ qkv_bias = qkv_bias
148
+ )
149
+ elif decoder == 'implicit_transformer':
150
+ self.only_embed = False
151
+ self.return_latents = True
152
+ self.decoder = ImplicitTransformerDecoder(
153
+ device = device,
154
+ dtype = dtype,
155
+ num_latents = 257,
156
+ out_channels = 1,
157
+ width = width,
158
+ heads = heads,
159
+ init_scale = init_scale,
160
+ flash = flash,
161
+ use_checkpoint = use_checkpoint,
162
+ qkv_bias = qkv_bias
163
+ )
164
+ elif decoder == 'triplane': #consider add these parameters to config
165
+ self.only_embed = True
166
+ self.return_latents = False
167
+ self.decoder = TriPlaneDecoder(
168
+ z_dim = 768,
169
+ c_dim = 0,
170
+ w_dim = 768,
171
+ mapping_kwargs = {'num_layers': 2},
172
+ synthesis_kwargs = {'num_fp16_res': 0, 'conv_clamp': None, 'fused_modconv_default': 'inference_only'}
173
+ )
174
+
175
+ elif decoder == 'Pointe_Diffusion':
176
+ self.only_embed = False
177
+ self.return_latents = True
178
+ self.decoder = Pointe_Diffusion(**kwargs)
179
+
180
+ elif decoder == 'Cross_Attention_Diffusion':
181
+ self.only_embed = False
182
+ self.return_latents = True
183
+ self.decoder = Cross_Attention_Diffusion(**kwargs)
184
+
185
+ elif decoder == 'attendjoints_combine':
186
+ self.only_embed = False
187
+ self.return_latents = True
188
+ self.decoder = AttendjointsDecoder_combine(
189
+ width = width,
190
+ device = device,
191
+ dtype = dtype,
192
+ heads = heads,
193
+ init_scale = init_scale,
194
+ flash = flash,
195
+ use_checkpoint = use_checkpoint,
196
+ separate = separate,
197
+ qkv_bias = qkv_bias
198
+ )
199
+ elif decoder == 'attendjoints_multi':
200
+ self.only_embed = False
201
+ self.return_latents = True
202
+ self.decoder = AttendjointsDecoder_multi(
203
+ width = width,
204
+ device = device,
205
+ dtype = dtype,
206
+ heads = heads,
207
+ init_scale = init_scale,
208
+ flash = flash,
209
+ use_checkpoint = use_checkpoint,
210
+ qkv_bias = qkv_bias,
211
+ separate=separate
212
+ )
213
+ elif decoder == 'attendjoints_con_combine':
214
+ self.only_embed = False
215
+ self.return_latents = True
216
+ self.decoder = AttendjointsDecoder_con_combine(
217
+ width = width,
218
+ device = device,
219
+ dtype = dtype,
220
+ heads = heads,
221
+ init_scale = init_scale,
222
+ flash = flash,
223
+ use_checkpoint = use_checkpoint,
224
+ qkv_bias = qkv_bias
225
+ )
226
+ elif decoder == 'attendjoints_con_token':
227
+ self.only_embed = False
228
+ self.return_latents = True
229
+ self.decoder = AttendjointsDecoder_con_token(
230
+ width = width,
231
+ device = device,
232
+ dtype = dtype,
233
+ heads = heads,
234
+ init_scale = init_scale,
235
+ flash = flash,
236
+ use_checkpoint = use_checkpoint,
237
+ qkv_bias = qkv_bias,
238
+ separate = separate
239
+ )
240
+
241
+ if encoder == 'miche':
242
+ if not self.load_encoder:
243
+ self.encoder = Encoder(only_embed=self.only_embed, config_path=config_path, ckpt_path=ckpt_path, device=device)
244
+ else:
245
+ self.encoder = Encoder(only_embed=self.only_embed, config_path=config_path, ckpt_path=None, device=device)
246
+ try:
247
+ print("=> loading encoder checkpoint '{}'".format(self.load_encoder))
248
+ checkpoint = torch.load(self.load_encoder, map_location='cpu')
249
+ state_dict = {k[8:]: v for k, v in checkpoint['state_dict'].items() if k.startswith('encoder')}
250
+ self.encoder.load_state_dict(state_dict)
251
+ print("=> loaded encoder checkpoint '{}'".format(self.load_encoder))
252
+ except:
253
+ print("=> no encoder checkpoint found at '{}'".format(self.load_encoder))
254
+ if self.load_encoder:
255
+ self.point_proj = nn.Sequential(
256
+ nn.Linear(768, 768, dtype=dtype),
257
+ nn.GELU(),
258
+ nn.Linear(768, 768, dtype=dtype),
259
+ )
260
+
261
+ if encoder == 'bert':
262
+ # model_name = 'RunsenXu/PointLLM_7B_v1.2'
263
+ # model = PointLLMLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=False, use_cache=True, torch_dtype=dtype)
264
+ # self.encoder = model.model.point_backbone.to(device)
265
+ # model = None
266
+ from ThirdParty.PointLLM.pointllm.model import PointTransformer
267
+ from ThirdParty.PointLLM.pointllm.utils import cfg_from_yaml_file
268
+ import os
269
+ # path to the config file, located in the same directory as this file
270
+ point_bert_config_name = "PointTransformer_8192point_2layer" # * default for v1.2, v1.1 uses PointTransformer_base_8192point.yaml
271
+ point_bert_config_addr = os.path.join("./ThirdParty/PointLLM/pointllm/model/pointbert/PointTransformer_8192point_2layer.yaml")
272
+ print(f"Loading PointBERT config from {point_bert_config_addr}.")
273
+ point_bert_config = cfg_from_yaml_file(point_bert_config_addr)
274
+ point_bert_config.model.point_dims = 6
275
+ use_max_pool = getattr(point_bert_config.model, "use_max_pool", False) # * default is false
276
+
277
+ self.encoder = PointTransformer(point_bert_config.model, use_max_pool=use_max_pool).to(device)
278
+ if self.return_latents:
279
+ self.point_proj = nn.Sequential(
280
+ nn.Linear(384, 512, dtype=dtype),
281
+ nn.GELU(),
282
+ nn.Linear(512, 512, dtype=dtype),
283
+ nn.GELU(),
284
+ nn.Linear(512, 768, dtype=dtype)
285
+ )
286
+ else:
287
+ self.point_proj = nn.ModuleList([
288
+ nn.Sequential(
289
+ nn.Linear(384, 512, dtype=dtype),
290
+ nn.GELU(),
291
+ nn.Linear(512, 512, dtype=dtype),
292
+ nn.GELU(),
293
+ nn.Linear(512, 768, dtype=dtype)
294
+ ),
295
+ nn.Linear(513, 1, dtype=dtype)
296
+ ])
297
+ if encoder == 'transformer':
298
+ self.points_cloud_embed = nn.Linear(
299
+ 768, 768, device=device, dtype=dtype
300
+ )
301
+ self.encoder = TransformerEncoder(device=device,dtype=dtype, num_latents=kwargs['num_latents'])
302
+
303
+
304
+
305
+ def encode(self, data, device='cuda'):
306
+ assert self.encoder_name in ['miche', 'bert', 'transformer'], f'Encoder {self.encoder_name} not supported'
307
+ if self.encoder_name == 'miche':
308
+ surface = data['points_cloud'].to(self.dtype).to(device)
309
+
310
+ # encoding
311
+ shape_embed, shape_latents = self.encoder.encode_shape_embed(surface, return_latents=True) # ShapeAsLatentPerceiver.encode_latents(): encoder
312
+
313
+ if self.only_embed:
314
+ if self.return_latents:
315
+ if self.load_encoder:
316
+ return self.point_proj(torch.cat([shape_embed.unsqueeze(1), shape_latents], dim=1))
317
+ return torch.cat([shape_embed.unsqueeze(1), shape_latents], dim=1) # torch.Size([bs, 257, 768]
318
+ return shape_embed # shape_embed: torch.Size([bs, 768])
319
+
320
+ shape_zq, posterior = self.encoder.encode_kl_embed(shape_latents) # ShapeAsLatentPerceiver.encode_kl_embed(): pre_kl + DiagonalGaussianDistribution()
321
+ # shape_zq, posterior = self.encoder.encode_kl_embed(shape_latents, sample_posterior=False) # not sample
322
+ # pretrained weight has 0 +- 0.7 mean and 0.5 +- 0.5 std
323
+ # trained weight has 0 +- 1.8 mean and 0.1 +- 0.1 std
324
+ # generally okay
325
+
326
+ latents = self.encoder.decode(shape_zq) # ShapeAsLatentPerceiver.decode(): post_kl + transformer
327
+
328
+ if not self.return_latents:
329
+ latents = torch.cat([shape_latents, latents], dim=1) # torch.Size([bs, 512, 768])
330
+
331
+ if self.load_encoder:
332
+ return self.point_proj(torch.cat([shape_embed.unsqueeze(1), latents], dim=1))
333
+ return torch.cat([shape_embed.unsqueeze(1), latents], dim=1) # torch.Size([bs, 257 / 513, 768])
334
+
335
+ if self.encoder_name == 'bert':
336
+ points = data['points_cloud'].to(self.dtype).to(device)
337
+ points = points[:, :, :3] / 2
338
+ points = torch.cat([points, torch.zeros_like(points)], dim=-1)
339
+ points = self.encoder(points)
340
+
341
+ if self.return_latents:
342
+ points = self.point_proj(points)
343
+ else:
344
+ points = self.point_proj[0](points)
345
+ points = self.point_proj[1](points.permute(0, 2, 1)).squeeze(-1)
346
+ return points
347
+
348
+ if self.encoder_name == 'transformer':
349
+ points = data['points_cloud'].to(self.dtype).to(device)
350
+ cond = self.encoder.encode_pc(points)
351
+ cond = self.points_cloud_embed(cond)
352
+ return cond
353
+
354
+ def forward(self, data, device='cuda', downsample=False, **kwargs):
355
+ latents = self.encode(data, device)
356
+ # print('latents shape', latents.shape)
357
+
358
+ logits = self.decoder(latents, data, device=device, downsample=downsample,**kwargs)
359
+
360
+ return logits
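A minimal sketch of the preprocessing that the 'bert' branch of encode() applies before calling PointTransformer, shown for illustration only (the tensor below is random stand-in data, not a repository asset): xyz coordinates are halved and zero-padded to the 6 input channels expected above.

import torch

# Stand-in for data['points_cloud']: (batch, num_points, 6) with xyz in the first 3 channels.
points_cloud = torch.rand(1, 8192, 6) - 0.5
points = points_cloud[:, :, :3] / 2                             # rescale xyz as encode() does
points = torch.cat([points, torch.zeros_like(points)], dim=-1)  # zero-pad the feature channels
print(points.shape)                                             # torch.Size([1, 8192, 6])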
Anymate/models/__init__.py ADDED
File without changes
Anymate/models/conn.py ADDED
@@ -0,0 +1,195 @@
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from ThirdParty.michelangelo.models.modules.transformer_blocks import ResidualCrossAttentionBlock, ResidualAttentionBlock, Transformer
4
+ from ThirdParty.michelangelo.models.modules.embedder import FourierEmbedder, components_from_spherical_harmonics
5
+
6
+ class AttendjointsDecoder_con_combine(nn.Module):
7
+ def __init__(self,
8
+ width = 768,
9
+ layers = 2,
10
+ device = 'cuda',
11
+ dtype = torch.float32,
12
+ heads = 12,
13
+ init_scale: float = 0.25,
14
+ flash = False,
15
+ use_checkpoint = False,
16
+ qkv_bias = False,
17
+ num_freqs: int = 8,
18
+ include_pi: bool = True,
19
+ separate = False,
20
+ use_mask = True):
21
+
22
+ super().__init__()
23
+
24
+ self.use_checkpoint = use_checkpoint
25
+ self.separate = separate
26
+ self.use_mask = use_mask
27
+ # self.num_latents = num_latents
28
+
29
+ # self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
30
+
31
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
32
+ self.co_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
33
+
34
+ # self.proj_attn = nn.Linear(width, width, device=device, dtype=dtype)
35
+
36
+ self.cross_attn = nn.ModuleList([ResidualCrossAttentionBlock(
37
+ device=device,
38
+ dtype=dtype,
39
+ width=width,
40
+ heads=heads,
41
+ init_scale=init_scale,
42
+ qkv_bias=qkv_bias,
43
+ flash=flash,
44
+ ) for _ in range(layers)])
45
+
46
+ self.self_attn = nn.ModuleList([ResidualAttentionBlock(
47
+ device=device,
48
+ dtype=dtype,
49
+ n_ctx=-1,
50
+ width=width,
51
+ heads=heads,
52
+ init_scale=init_scale,
53
+ qkv_bias=qkv_bias,
54
+ flash=flash,
55
+ ) for _ in range(layers * 2)])
56
+
57
+ # self.joint_embed_proj = nn.ModuleList([nn.Linear(width, width, device=device, dtype=dtype) for _ in range(layers)])
58
+
59
+
60
+ self.q_proj = nn.Linear(width, width, device=device, dtype=dtype)
61
+ self.k_proj = nn.Linear(width, width, device=device, dtype=dtype)
62
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
63
+ self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
64
+
65
+ # self.last_cross_attn = ResidualCrossAttentionBlock(
66
+ # device=device,
67
+ # dtype=dtype,
68
+ # width=width,
69
+ # heads=heads,
70
+ # init_scale=init_scale,
71
+ # qkv_bias=qkv_bias,
72
+ # flash=flash,
73
+ # )
74
+ # self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
75
+ # self.output_proj = nn.Linear(width, 1, device=device, dtype=dtype)
76
+
77
+ def forward(self, latents, data=None, device='cuda', downsample=None, dtype=torch.float32):
78
+
79
+ joints = data['joints'].to(device)
80
+ max_joints = max(data['joints_num'])
81
+ joints = joints[:, :max_joints, :3]
82
+
83
+ joints_embeds = self.fourier_embedder(joints)
84
+ joints_embeds = self.co_proj(joints_embeds)
85
+
86
+ joints_num = joints_embeds.shape[-2]
87
+
88
+ x = [joints_embeds, joints_embeds.clone()]
89
+
90
+ for i in range(2):
91
+ for j, layer in enumerate(self.cross_attn):
92
+
93
+ x[i] = layer(x[i], latents)
94
+
95
+ if self.use_mask:
96
+ x[i] = self.self_attn[2*i+j](x[i], mask=data['joints_mask'].to(device))
97
+ else:
98
+ x[i] = self.self_attn[2*i+j](x[i])
99
+
100
+ # Dot product between the two joint-embedding streams -> (b, n, m) connectivity logits
101
+ logits = torch.einsum('bnc,bmc->bnm', self.k_proj(self.ln_1(x[0])), self.q_proj(self.ln_2(x[1]))) # (b, n, m)
102
+
103
+ if self.use_mask:
104
+ mask = data['joints_mask'].to(device)
105
+ logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e8)
106
+
107
+ return logits
108
+
109
+ class AttendjointsDecoder_con_token(nn.Module):
110
+ def __init__(self,
111
+ width = 768,
112
+ layers = 4,
113
+ device = 'cuda',
114
+ dtype = torch.float32,
115
+ heads = 12,
116
+ init_scale: float = 0.25,
117
+ flash = False,
118
+ use_checkpoint = False,
119
+ qkv_bias = False,
120
+ num_freqs: int = 8,
121
+ include_pi: bool = True,
122
+ head_token_length =128,
123
+ separate = False,
124
+ use_mask = True):
125
+
126
+ super().__init__()
127
+
128
+ self.use_checkpoint = use_checkpoint
129
+ self.use_mask = use_mask
130
+ self.layer_norm = nn.LayerNorm(width)
131
+ self.head_token = nn.Parameter(torch.randn((1, 1, head_token_length), device=device, dtype=dtype) * 0.02)
132
+ self.tail_token = nn.Parameter(torch.randn((1, 1, head_token_length), device=device, dtype=dtype) * 0.02)
133
+ self.head_mlp = nn.ModuleList([
134
+ nn.Linear(width + head_token_length, 512, device=device, dtype=dtype),
135
+ nn.Linear(512, 512, device=device, dtype=dtype),
136
+ nn.Linear(512, width, device=device, dtype=dtype),
137
+ nn.LayerNorm(width)
138
+
139
+ ])
140
+ self.tail_mlp = nn.ModuleList([
141
+ nn.Linear(width + head_token_length, 512, device=device, dtype=dtype),
142
+ nn.Linear(512, 512, device=device, dtype=dtype),
143
+ nn.Linear(512, width, device=device, dtype=dtype),
144
+ nn.LayerNorm(width)
145
+ ])
146
+
147
+ self.self_attn = Transformer(
148
+ device=device,
149
+ dtype=dtype,
150
+ n_ctx=-1,
151
+ width=width,
152
+ layers=layers,
153
+ heads=heads,
154
+ init_scale=init_scale,
155
+ qkv_bias=qkv_bias,
156
+ flash=flash,
157
+ use_checkpoint=False,
158
+ )
159
+ self.separate = separate
160
+ self.normal_embedder = components_from_spherical_harmonics
161
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
162
+ self.joints_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
163
+ self.output_proj_joints = nn.Linear(width, width, device=device, dtype=dtype)
164
+
165
+ def forward(self, latents, data=None,device='cuda', downsample=None, dtype='float32'):
166
+ joints = data['joints'].to(device)
167
+ max_joints = max(data['joints_num'])
168
+ joints = joints[:, :max_joints, :3]
169
+ joints_embeds_fourier = self.fourier_embedder(joints)
170
+ joints_embeds = self.joints_proj(joints_embeds_fourier)
171
+ # Concatenate embeddings
172
+ x = torch.cat([joints_embeds, latents], dim=-2) # (b, max_joint+token_num, c)
173
+ # Pass through self-attention
174
+ if self.use_mask:
175
+ mask = data['mask'].to(device)
176
+ append_size = x.shape[1] - mask.shape[1]  # number of latent tokens that need mask entries appended (filled with ones below)
177
+ batch_size = mask.shape[0]
178
+
179
+ mask_extend = torch.ones((batch_size,append_size)).to(device)
180
+ mask = torch.cat([mask,mask_extend],dim=-1).to(device)
181
+
182
+ x = self.self_attn(x,mask)
183
+ else:
184
+ x = self.self_attn(x)
185
+ joints, _= x.split([joints_embeds.shape[1], latents.shape[1]], dim=1)
186
+ joints = self.output_proj_joints(self.layer_norm(joints))
187
+ joints_head = torch.concat([joints, self.head_token.repeat(joints.shape[0],joints.shape[1],1)], dim=-1)
188
+ joints_tail = torch.concat([joints, self.tail_token.repeat(joints.shape[0],joints.shape[1],1)], dim=-1)
189
+ for layer in self.head_mlp:
190
+ joints_head = layer(joints_head)
191
+ for layer in self.tail_mlp:
192
+ joints_tail = layer(joints_tail)
193
+ logits = torch.einsum('bik,bjk->bij', joints_head, joints_tail)
194
+
195
+ return logits
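A small post-processing sketch for the connectivity decoders above: they return a (B, J, J) logit matrix, so one plausible way to read out a parent per joint is an argmax over the last axis. Which axis is child and which is parent is an assumption here, and the tensor below is synthetic.

import torch

logits = torch.randn(1, 24, 24)            # stand-in for decoder output over 24 joints
parents = logits.argmax(dim=-1)            # (1, 24): one candidate parent index per joint
probs = torch.softmax(logits, dim=-1)      # per-joint distribution over candidate parents
print(parents[0].tolist(), probs.shape)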
Anymate/models/diffusion.py ADDED
@@ -0,0 +1,483 @@
 
 
1
+ F"""
2
+ Adapted from: https://github.com/openai/openai/blob/55363aa496049423c37124b440e9e30366db3ed6/orc/orc/diffusion/vit.py
3
+ """
4
+
5
+ import math
6
+ from typing import Any, Dict, Iterable, List, Optional, Sequence, Tuple, Union, Callable
7
+
8
+ import torch
9
+ import torch.nn as nn
10
+
11
+ from einops import repeat
12
+ from Anymate.utils.diffusion_utils import *
13
+ from ThirdParty.michelangelo.models.modules.transformer_blocks import Transformer, ResidualCrossAttentionBlock
14
+
15
+ from diffusers import DDPMScheduler, DDIMScheduler
16
+ from sklearn.cluster import DBSCAN
17
+
18
+ def init_linear(l, stddev):
19
+ nn.init.normal_(l.weight, std=stddev)
20
+ if l.bias is not None:
21
+ nn.init.constant_(l.bias, 0.0)
22
+
23
+ class projection_transformer(nn.Module):
24
+ def __init__(self, num_latents=16, width = 16, heads=8, dtype = torch.float32):
25
+ super().__init__()
26
+ self.num_latents = num_latents
27
+ self.query = nn.Parameter(torch.randn((num_latents, width), dtype=dtype) * 0.02)
28
+
29
+ self.cross_attn = ResidualCrossAttentionBlock(
30
+ device= 'cuda',
31
+ dtype=dtype,
32
+ width=width,
33
+ heads=heads,
34
+ init_scale=0.25,
35
+ qkv_bias=True,
36
+ flash=False,
37
+ )
38
+ self.output_proj = nn.Linear(width, width,dtype=dtype)
39
+
40
+ def forward(self, latents):
41
+ bs = latents.shape[0]
42
+ query = repeat(self.query, "m c -> b m c", b=bs)
43
+ embed = self.cross_attn(query, latents)
44
+ logits = self.output_proj(embed)
45
+
46
+ return logits
47
+
48
+ def timestep_embedding(timesteps, dim, max_period=10000):
49
+ """
50
+ Create sinusoidal timestep embeddings.
51
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
52
+ These may be fractional.
53
+ :param dim: the dimension of the output.
54
+ :param max_period: controls the minimum frequency of the embeddings.
55
+ :return: an [N x dim] Tensor of positional embeddings.
56
+ """
57
+ half = dim // 2
58
+ freqs = torch.exp(
59
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
60
+ ).to(device=timesteps.device)
61
+ args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
62
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
63
+ if dim % 2:
64
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
65
+ return embedding
66
+
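A quick shape check for timestep_embedding above (the import path is assumed; the function itself is defined a few lines up): cosine terms fill the first half of the output and sine terms the second half.

import torch
from Anymate.models.diffusion import timestep_embedding  # assumed import path

t = torch.tensor([0.0, 10.0, 999.0])
emb = timestep_embedding(t, dim=768)
print(emb.shape)   # torch.Size([3, 768]); emb[:, :384] are cos terms, emb[:, 384:] are sin terms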
67
+ class MultiheadAttention(nn.Module):
68
+ def __init__(
69
+ self,
70
+ *,
71
+ dtype: torch.dtype,
72
+ n_ctx: int,
73
+ width: int,
74
+ heads: int,
75
+ init_scale: float,
76
+ ):
77
+ super().__init__()
78
+ self.n_ctx = n_ctx
79
+ self.width = width
80
+ self.heads = heads
81
+ self.c_qkv = nn.Linear(width, width * 3, dtype=dtype)
82
+ self.c_proj = nn.Linear(width, width, dtype=dtype)
83
+ self.attention = QKVMultiheadAttention(dtype=dtype, heads=heads, n_ctx=n_ctx)
84
+ init_linear(self.c_qkv, init_scale)
85
+ init_linear(self.c_proj, init_scale)
86
+
87
+ def forward(self, x):
88
+ x = self.c_qkv(x)
89
+ x = self.attention(x)
90
+ x = self.c_proj(x)
91
+ return x
92
+
93
+ class MLP(nn.Module):
94
+ def __init__(self, *, dtype: torch.dtype, width: int, init_scale: float):
95
+ super().__init__()
96
+ self.width = width
97
+ self.c_fc = nn.Linear(width, width * 4, dtype=dtype)
98
+ self.c_proj = nn.Linear(width * 4, width, dtype=dtype)
99
+ self.gelu = nn.GELU()
100
+ init_linear(self.c_fc, init_scale)
101
+ init_linear(self.c_proj, init_scale)
102
+
103
+ def forward(self, x):
104
+ return self.c_proj(self.gelu(self.c_fc(x)))
105
+
106
+ class QKVMultiheadAttention(nn.Module):
107
+ def __init__(self, *, dtype: torch.dtype, heads: int, n_ctx: int):
108
+ super().__init__()
109
+ self.dtype = dtype
110
+ self.heads = heads
111
+ self.n_ctx = n_ctx
112
+
113
+ def forward(self, qkv):
114
+ bs, n_ctx, width = qkv.shape
115
+ attn_ch = width // self.heads // 3
116
+ scale = 1 / math.sqrt(math.sqrt(attn_ch))
117
+ qkv = qkv.view(bs, n_ctx, self.heads, -1)
118
+ q, k, v = torch.split(qkv, attn_ch, dim=-1)
119
+ weight = torch.einsum(
120
+ "bthc,bshc->bhts", q * scale, k * scale
121
+ ) # More stable with f16 than dividing afterwards
122
+ wdtype = weight.dtype
123
+ weight = torch.softmax(weight.float(), dim=-1).type(wdtype)
124
+ return torch.einsum("bhts,bshc->bthc", weight, v).reshape(bs, n_ctx, -1)
125
+
126
+ class ResidualAttentionBlock(nn.Module):
127
+ def __init__(
128
+ self,
129
+ *,
130
+ dtype: torch.dtype,
131
+ n_ctx: int,
132
+ width: int,
133
+ heads: int,
134
+ init_scale: float = 1.0,
135
+ ):
136
+ super().__init__()
137
+
138
+ self.attn = MultiheadAttention(
139
+ dtype=dtype,
140
+ n_ctx=n_ctx,
141
+ width=width,
142
+ heads=heads,
143
+ init_scale=init_scale,
144
+ )
145
+ self.ln_1 = nn.LayerNorm(width, dtype=dtype)
146
+ self.mlp = MLP(dtype=dtype, width=width, init_scale=init_scale)
147
+ self.ln_2 = nn.LayerNorm(width, dtype=dtype)
148
+
149
+ def forward(self, x: torch.Tensor):
150
+ x = x + self.attn(self.ln_1(x))
151
+ x = x + self.mlp(self.ln_2(x))
152
+ return x
153
+
154
+ class Transformer(nn.Module):
155
+ def __init__(
156
+ self,
157
+ *,
158
+ dtype: torch.dtype,
159
+ n_ctx: int,
160
+ width: int,
161
+ layers: int,
162
+ heads: int,
163
+ init_scale: float = 0.25,
164
+ ):
165
+ super().__init__()
166
+ self.n_ctx = n_ctx
167
+ self.width = width
168
+ self.layers = layers
169
+ init_scale = init_scale * math.sqrt(1.0 / width)
170
+ self.resblocks = nn.ModuleList(
171
+ [
172
+ ResidualAttentionBlock(
173
+ dtype=dtype,
174
+ n_ctx=n_ctx,
175
+ width=width,
176
+ heads=heads,
177
+ init_scale=init_scale,
178
+ )
179
+ for _ in range(layers)
180
+ ]
181
+ )
182
+
183
+ def forward(self, x: torch.Tensor):
184
+ for block in self.resblocks:
185
+ x = block(x)
186
+ return x
187
+
188
+ class PointDiffusionTransformer(nn.Module):
189
+ def __init__(
190
+ self,
191
+ *,
192
+ dtype: torch.dtype,
193
+ input_channels: int = 3,
194
+ output_channels: int = 3,
195
+ n_ctx: int = 1024,
196
+ width: int = 768,
197
+ layers: int = 12,
198
+ heads: int = 8,
199
+ init_scale: float = 0.25,
200
+ time_token_cond: bool = True,
201
+ ):
202
+ super().__init__()
203
+ self.input_channels = input_channels
204
+ self.output_channels = output_channels
205
+ self.n_ctx = n_ctx
206
+ self.time_token_cond = time_token_cond
207
+ self.time_embed = MLP(
208
+ dtype=dtype, width=width, init_scale=init_scale * math.sqrt(1.0 / width)
209
+ )
210
+ self.ln_pre = nn.LayerNorm(width, dtype=dtype)
211
+ self.backbone = Transformer(
212
+ dtype=dtype,
213
+ n_ctx=n_ctx + int(time_token_cond),
214
+ width=width,
215
+ layers=layers,
216
+ heads=heads,
217
+ init_scale=init_scale,
218
+ )
219
+ self.ln_post = nn.LayerNorm(width,dtype=dtype)
220
+ self.input_proj = nn.Linear(input_channels, width, dtype=dtype)
221
+ self.output_proj = nn.Linear(width, output_channels,dtype=dtype)
222
+ with torch.no_grad():
223
+ self.output_proj.weight.zero_()
224
+ self.output_proj.bias.zero_()
225
+
226
+ def forward(self, x: torch.Tensor, t: torch.Tensor):
227
+ """
228
+ :param x: an [N x C x T] tensor.
229
+ :param t: an [N] tensor.
230
+ :return: an [N x C' x T] tensor.
231
+ """
232
+ assert x.shape[-1] == self.n_ctx
233
+ t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
234
+ return self._forward_with_cond(x, [(t_embed, self.time_token_cond)])
235
+
236
+ def _forward_with_cond(
237
+ self, x: torch.Tensor, cond_as_token: List[Tuple[torch.Tensor, bool]]
238
+ ) -> torch.Tensor:
239
+ h = self.input_proj(x.permute(0, 2, 1)) # NCL -> NLC
240
+ for emb, as_token in cond_as_token:
241
+ if not as_token:
242
+ h = h + emb[:, None]
243
+ extra_tokens = [
244
+ (emb[:, None] if len(emb.shape) == 2 else emb)
245
+ for emb, as_token in cond_as_token
246
+ if as_token
247
+ ]
248
+ if len(extra_tokens):
249
+ h = torch.cat(extra_tokens + [h], dim=1)
250
+
251
+ h = self.ln_pre(h)
252
+ h = self.backbone(h)
253
+ h = self.ln_post(h)
254
+ if len(extra_tokens):
255
+ h = h[:, sum(h.shape[1] for h in extra_tokens) :]
256
+ h = self.output_proj(h)
257
+ return h.permute(0, 2, 1)
258
+
259
+ class Pointe_Diffusion(PointDiffusionTransformer):
260
+ '''
261
+ input: data: data dict
262
+ x: [N x C x T] tensor
263
+ t: [N] tensor
264
+ init:
265
+ n_ctx: int = 1024: context length
266
+ '''
267
+ def __init__(
268
+ self,
269
+ *,
270
+ device = 'cuda',
271
+ dtype = torch.float32,
272
+ encoder = 'miche',
273
+ n_ctx: int = 1024,
274
+ token_cond: bool = True,
275
+ cond_drop_prob: float = 0.1,
276
+ fix_emb: bool = False,
277
+
278
+ **kwargs,
279
+ ):
280
+ super().__init__(dtype=dtype, n_ctx=n_ctx + int(token_cond), **kwargs)
281
+ self.n_ctx = n_ctx
282
+ self.token_cond = token_cond
283
+ # self.proj_transformer = projection_transformer(**kwargs)
284
+ self.encoder_name = encoder
285
+ self.cond_drop_prob = cond_drop_prob
286
+ self.fix_emb = fix_emb
287
+ self.dtype = dtype
288
+ self.inference = False
289
+ def cached_model_kwargs(self, batch_size: int, model_kwargs: Dict[str, Any]) -> Dict[str, Any]:
290
+ with torch.no_grad():
291
+ return dict(embeddings=self.clip(batch_size, **model_kwargs))
292
+
293
+ def inference_mode(self,eps=0.03):
294
+ self.inference = True
295
+
296
+ def forward_func(
297
+ self,
298
+ latent: torch.Tensor,
299
+ data,
300
+ device='cuda',
301
+ downsample = False,
302
+ **kwargs,
303
+ ):
304
+ t = kwargs['timesteps'].to(latent.device)
305
+ x = kwargs['noisy_joints'].to(latent.device)
306
+ assert x.shape[-1] == self.n_ctx, f"x shape: {x.shape}, n_ctx: {self.n_ctx}"
307
+ t_embed = self.time_embed(timestep_embedding(t, self.backbone.width))
308
+
309
+ if self.training:
310
+ mask = torch.rand(size=[len(x)]) >= self.cond_drop_prob
311
+ latent = latent * mask[:,None,None].to(latent.device)
312
+
313
+ latent = [(latent, self.token_cond), (t_embed, self.time_token_cond)]
314
+ return self._forward_with_cond(x, latent)
315
+
316
+ def forward(self, latent, data, device='cuda', downsample = False, **kwargs):
317
+ if not self.inference:
318
+ return self.forward_func(latent, data, device, downsample, **kwargs)
319
+ else:
320
+ generator=torch.Generator(device='cpu')
321
+ scheduler = DDIMScheduler(100)
322
+ scheduler.set_timesteps(100)
323
+ points_shape = [1, self.n_ctx, 3]
324
+
325
+ points_noise = randn_tensor(points_shape, generator=generator)
326
+ points = points_noise.permute(0, 2, 1).to(latent.device)
327
+ for t in scheduler.timesteps:
328
+ with torch.no_grad():
329
+ time_steps = torch.ones(1, 1, dtype=torch.long) * t
330
+ model_output = self.forward_func(latent, data, noisy_joints=points, timesteps = time_steps)
331
+
332
+ points = scheduler.step(model_output, t, points, generator=generator).prev_sample
333
+ points = points.permute(0, 2, 1).cpu()
334
+ assert points.shape[0] == 1, "Inference mode only supports batch size 1"
335
+ joints = points[0].detach().cpu().numpy()
336
+ clustering = DBSCAN(eps=0.05, min_samples=1).fit(joints)
337
+ cluster_centers = []
338
+ for cluster in set(clustering.labels_):
339
+ cluster_centers.append(joints[clustering.labels_ == cluster].mean(axis=0))
340
+ return cluster_centers
341
+
342
+ class Cross_Attention_Diffusion(nn.Module):
343
+ def __init__(self,
344
+ input_channels=3, output_channels=3,
345
+ num_z=16, num_x=1024, z_dim=768, x_dim=512,
346
+ num_blocks=6, num_compute_layers=4, num_heads=8,
347
+ mlp_ratio=4., qkv_bias=True, drop=0., attn_drop=0.,
348
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm,num_latents=16,
349
+ device=torch.device('cuda' if torch.cuda.is_available() else 'cpu'),
350
+ use_projection = True,):
351
+ super().__init__()
352
+ self.use_projection = use_projection
353
+ self.device = device
354
+ self.num_z = num_z
355
+ self.num_x = num_x
356
+ self.z_dim = z_dim
357
+ if use_projection:
358
+ self.proj_transformer = projection_transformer(num_latents=num_latents, width=z_dim, heads=num_heads)
359
+ self.prev_latent = nn.Parameter(torch.zeros(1, self.num_z + num_latents + 1, z_dim))
360
+ self.inference = False
361
+
362
+ self.input_proj = nn.Linear(input_channels, x_dim)
363
+ self.ln_pre = nn.LayerNorm(x_dim)
364
+ self.z_init = nn.Parameter(torch.zeros(1, num_z, z_dim))
365
+
366
+ mlp_hidden_dim = int(z_dim * mlp_ratio)
367
+ self.time_embed = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim)
368
+
369
+ self.latent_mlp = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
370
+ self.ln_latent = nn.LayerNorm(z_dim)
371
+ self.blocks = nn.ModuleList([
372
+ RCW_Block(z_dim, x_dim, num_compute_layers=num_compute_layers,
373
+ num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias,
374
+ drop=drop, attn_drop=attn_drop, drop_path=drop_path,
375
+ act_layer=act_layer, norm_layer=norm_layer)
376
+ for _ in range(num_blocks)
377
+ ])
378
+
379
+ # output blocks
380
+ self.ln_post = nn.LayerNorm(x_dim)
381
+ self.output_proj = nn.Linear(x_dim, output_channels)
382
+
383
+ self.initialize_weights()
384
+
385
+ def initialize_weights(self):
386
+ nn.init.normal_(self.z_init, std=.02)
387
+
388
+ # initialize nn.Linear and nn.LayerNorm
389
+ self.apply(self._init_weights)
390
+
391
+ nn.init.constant_(self.ln_latent.weight, 0)
392
+ nn.init.constant_(self.ln_latent.bias, 0)
393
+
394
+ def _init_weights(self, m):
395
+ if isinstance(m, nn.Linear):
396
+ torch.nn.init.xavier_uniform_(m.weight)
397
+ if isinstance(m, nn.Linear) and m.bias is not None:
398
+ nn.init.constant_(m.bias, 0)
399
+ elif isinstance(m, nn.LayerNorm):
400
+ nn.init.constant_(m.bias, 0)
401
+ nn.init.constant_(m.weight, 1.0)
402
+
403
+ def inference_mode(self,eps=0.03):
404
+ self.inference = True
405
+
406
+ def forward_func(self, latent, data, device='cuda', downsample = False, **kwargs):
407
+ """
408
+ Forward pass of the model.
409
+
410
+ Parameters:
411
+ x: [B, num_x, C_in]
412
+ t: [B]
413
+ cond: [B, num_cond, C_latent]
414
+ prev_latent: [B, num_z + num_cond + 1, C_latent]
415
+
416
+ Returns:
417
+ x_denoised: [B, num_x, C_out]
418
+ z: [B, num_z + num_cond + 1, C_latent]
419
+ """
420
+ t = kwargs['timesteps'].to(latent.device)
421
+ x = kwargs['noisy_joints'].to(latent.device)
422
+ x = x.permute(0, 2, 1)
423
+ B, num_x, _ = x.shape
424
+ if self.use_projection:
425
+ latent = self.proj_transformer(latent)
426
+ assert num_x == self.num_x, f"x shape: {x.shape}, num_x: {self.num_x}"
427
+ # if prev_latent is not None:
428
+ # _, num_z, _ = prev_latent.shape
429
+ # assert num_z == self.num_z + num_cond + 1
430
+ # else:
431
+ # prev_latent = torch.zeros(B, self.num_z + num_cond + 1, self.z_dim).to(x.device)
432
+
433
+ # timestep embedding, [B, 1, z_dim]
434
+ t_embed = self.time_embed(timestep_embedding(t, self.z_dim))
435
+ if t_embed.dim() == 2:
436
+ t_embed = t_embed.unsqueeze(1)
437
+
438
+ # project x -> [B, num_x, C_x]
439
+ x = self.input_proj(x)
440
+ x = self.ln_pre(x)
441
+
442
+ # latent self-conditioning
443
+ z = self.z_init.repeat(B, 1, 1)  # [B, num_z, z_dim]
444
+ z = torch.cat([z, latent, t_embed], dim=1) # [B, num_z + num_cond + 1, z_dim]
445
+ prev_latent = self.prev_latent + self.latent_mlp(self.prev_latent.detach())
446
+ z = z + (self.ln_latent(prev_latent))
447
+
448
+ # compute
449
+ for blk in self.blocks:
450
+ z, x = blk(z, x)
451
+
452
+ # output proj
453
+ x = self.ln_post(x)
454
+ x_denoised = self.output_proj(x)
455
+ return x_denoised.permute(0, 2, 1)
456
+
457
+ def forward(self, latent, data, device='cuda', downsample = False, **kwargs):
458
+ if not self.inference:
459
+ return self.forward_func(latent, data, device, downsample, **kwargs)
460
+ else:
461
+ generator=torch.Generator(device='cpu')
462
+ scheduler = DDIMScheduler(100)
463
+ scheduler.set_timesteps(100)
464
+ points_shape = [1, self.num_x, 3]
465
+
466
+ points_noise = randn_tensor(points_shape, generator=generator)
467
+ points = points_noise.permute(0, 2, 1).to(latent.device)
468
+ for t in scheduler.timesteps:
469
+ with torch.no_grad():
470
+ time_steps = torch.ones(1, 1, dtype=torch.long) * t
471
+ time_steps = time_steps.to(latent.device)
472
+ model_output = self.forward_func(latent, data, noisy_joints=points, timesteps = time_steps)
473
+
474
+ points = scheduler.step(model_output, t, points, generator=generator).prev_sample
475
+ points = points.permute(0, 2, 1).cpu()
476
+ assert points.shape[0] == 1, "Inference mode only supports batch size 1"
477
+ joints = points[0].detach().cpu().numpy()
478
+ clustering = DBSCAN(eps=0.05, min_samples=1).fit(joints)
479
+ cluster_centers = []
480
+ for cluster in set(clustering.labels_):
481
+ cluster_centers.append(joints[clustering.labels_ == cluster].mean(axis=0))
482
+ return cluster_centers
483
+
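Both diffusion heads above collapse the sampled 3D point set into joint predictions by clustering with DBSCAN and averaging each cluster; a standalone sketch of that step, on synthetic points, looks like this.

import numpy as np
from sklearn.cluster import DBSCAN

# Two tight blobs standing in for the (n_ctx, 3) points produced by the DDIM sampler above.
rng = np.random.default_rng(0)
joints = np.concatenate([
    rng.normal([0.0, 0.5, 0.0], 0.01, size=(100, 3)),
    rng.normal([0.0, -0.5, 0.0], 0.01, size=(100, 3)),
])
clustering = DBSCAN(eps=0.05, min_samples=1).fit(joints)
cluster_centers = [joints[clustering.labels_ == c].mean(axis=0)
                   for c in set(clustering.labels_)]
print(len(cluster_centers))   # expected: 2 cluster centers -> 2 predicted joints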
Anymate/models/joint.py ADDED
@@ -0,0 +1,282 @@
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from ThirdParty.michelangelo.models.modules.embedder import FourierEmbedder
4
+ from ThirdParty.michelangelo.models.modules.transformer_blocks import ResidualCrossAttentionBlock
5
+ from ThirdParty.eg3d.training.networks_stylegan2 import Generator as StyleGAN2Backbone
6
+ from ThirdParty.eg3d.training.networks_stylegan2 import FullyConnectedLayer
7
+ from Anymate.utils.vol_utils import get_co, sample_from_planes, generate_planes
8
+ from einops import repeat
9
+ from sklearn.cluster import DBSCAN
10
+ from Anymate.utils.vol_utils import extract_keypoints
11
+
12
+ class TransformerDecoder(nn.Module):
13
+ def __init__(self,
14
+ num_latents = 96,
15
+ num_kv_latents = 257,
16
+ out_channels = 3,
17
+ width = 768,
18
+ layers = 7,
19
+ device = 'cuda',
20
+ dtype = torch.float32,
21
+ heads = 12,
22
+ init_scale: float = 0.25,
23
+ flash = False,
24
+ use_checkpoint = False,
25
+ qkv_bias = False):
26
+
27
+ super().__init__()
28
+
29
+ self.use_checkpoint = use_checkpoint
30
+ self.num_latents = num_latents
31
+ self.inference = False
32
+ self.eps = 0.03
33
+
34
+ self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
35
+
36
+ self.cross_attn_decoder = ResidualCrossAttentionBlock(
37
+ device=device,
38
+ dtype=dtype,
39
+ n_data=num_kv_latents,
40
+ width=width,
41
+ heads=heads,
42
+ init_scale=init_scale,
43
+ qkv_bias=qkv_bias,
44
+ flash=flash
45
+ )
46
+
47
+ self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
48
+ self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
49
+
50
+ def inference_mode(self, eps=0.03, min_samples=1):
51
+ self.inference = True
52
+ self.eps = eps
53
+ self.min_samples = min_samples
54
+
55
+ def forward(self, latents, data=None, device='cuda', downsample=False, dtype=torch.float32):
56
+
57
+ bs = latents.shape[0]
58
+ query = repeat(self.query, "m c -> b m c", b=bs)
59
+ logits = self.cross_attn_decoder(query, latents)
60
+ logits = self.ln_post(logits)
61
+ logits = self.output_proj(logits)
62
+ if self.inference:
63
+ assert logits.shape[0] == 1, "Inference mode only supports batch size 1"
64
+ joints = logits[0].detach().cpu().numpy()
65
+ clustering = DBSCAN(eps=self.eps, min_samples=self.min_samples).fit(joints)
66
+ cluster_centers = []
67
+ for cluster in set(clustering.labels_):
68
+ cluster_centers.append(joints[clustering.labels_ == cluster].mean(axis=0))
69
+ return cluster_centers
70
+ return logits
71
+
72
+
73
+ class ImplicitTransformerDecoder(nn.Module):
74
+
75
+ def __init__(self, *,
76
+ device = 'cuda',
77
+ dtype = torch.float32,
78
+ num_latents = 257,
79
+ out_channels = 1,
80
+ width = 768,
81
+ heads = 12,
82
+ num_freqs: int = 8,
83
+ include_pi: bool = True,
84
+ init_scale: float = 0.25,
85
+ qkv_bias: bool = False,
86
+ flash: bool = False,
87
+ use_checkpoint: bool = False):
88
+
89
+ super().__init__()
90
+
91
+ self.use_checkpoint = use_checkpoint
92
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
93
+ self.inference = False
94
+
95
+ self.query_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
96
+
97
+ self.cross_attn_decoder = ResidualCrossAttentionBlock(
98
+ device=device,
99
+ dtype=dtype,
100
+ n_data=num_latents,
101
+ width=width,
102
+ heads=heads,
103
+ init_scale=init_scale,
104
+ qkv_bias=qkv_bias,
105
+ flash=flash
106
+ )
107
+
108
+ self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
109
+ self.output_proj = nn.Linear(width, out_channels, device=device, dtype=dtype)
110
+
111
+ # self.queries = get_vol().to(device)
112
+
113
+ def inference_mode(self):
114
+ self.inference = True
115
+
116
+ def forward(self, latents: torch.FloatTensor, data=None, device='cuda', downsample=False):
117
+ bs = latents.shape[0]
118
+ # queries = repeat(self.queries, "m c -> b m c", b=bs)
119
+ out = []
120
+ for b in range(bs):
121
+ queries = get_co(data['vox'][b]).to(device).unsqueeze(0)
122
+ if downsample and data['vox'][b].shape[0] > 50000:
123
+ # random sample
124
+ idx = torch.randperm(data['vox'][b].shape[0])[:50000]
125
+ queries = queries[:, idx]
126
+ queries = self.query_proj(self.fourier_embedder(queries))
127
+ x = self.cross_attn_decoder(queries, latents[b:b+1])
128
+ x = self.ln_post(x)
129
+ x = self.output_proj(x)
130
+ if downsample and data['vox'][b].shape[0] > 50000:
131
+ out.append((x.squeeze(0), idx))
132
+ else:
133
+ out.append(x.squeeze(0))
134
+ if self.inference:
135
+ assert len(out) == 1, "Inference mode only supports batch size 1"
136
+ return extract_keypoints(out[0], data['vox'][0])
137
+
138
+ return out
139
+
140
+
141
+ class TriPlaneDecoder(torch.nn.Module):
142
+ def __init__(self,
143
+ z_dim = 768, # Input latent (Z) dimensionality.
144
+ c_dim = 0, # Conditioning label (C) dimensionality.
145
+ w_dim = 768, # Intermediate latent (W) dimensionality.
146
+ # img_resolution, # Output resolution.
147
+ # img_channels, # Number of output color channels.
148
+ # sr_num_fp16_res = 0,
149
+ mapping_kwargs = {'num_layers': 2}, # Arguments for MappingNetwork.
150
+ # rendering_kwargs = {},
151
+ # sr_kwargs = {},
152
+ synthesis_kwargs = {'num_fp16_res': 0, 'conv_clamp': None, 'fused_modconv_default': 'inference_only'}, # Arguments for SynthesisNetwork.
153
+ ):
154
+ super().__init__()
155
+ self.z_dim=z_dim
156
+ self.c_dim=c_dim
157
+ self.w_dim=w_dim
158
+ # self.img_resolution=img_resolution
159
+ # self.img_channels=img_channels
160
+ # self.renderer = ImportanceRenderer()
161
+ # self.ray_sampler = RaySampler()
162
+ self.backbone = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=256, img_channels=32*3, mapping_kwargs=mapping_kwargs, **synthesis_kwargs)
163
+ # self.superresolution = dnnlib.util.construct_class_by_name(class_name=rendering_kwargs['superresolution_module'], channels=32, img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs)
164
+ self.decoder = OSGDecoder(32, {'decoder_output_dim': 0})
165
+ self.inference = False
166
+ # self.neural_rendering_resolution = 64
167
+ # self.rendering_kwargs = rendering_kwargs
168
+
169
+ self._last_planes = None
170
+ self.plane_axes = generate_planes()
171
+
172
+ def mapping(self, z, c=None, truncation_psi=1, truncation_cutoff=None, update_emas=False):
173
+ # if self.rendering_kwargs['c_gen_conditioning_zero']:
174
+ # c = torch.zeros_like(c)
175
+ # return self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
176
+ return self.backbone.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
177
+
178
+ def synthesis(self, ws, c=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
179
+ # cam2world_matrix = c[:, :16].view(-1, 4, 4)
180
+ # intrinsics = c[:, 16:25].view(-1, 3, 3)
181
+
182
+ # if neural_rendering_resolution is None:
183
+ # neural_rendering_resolution = self.neural_rendering_resolution
184
+ # else:
185
+ # self.neural_rendering_resolution = neural_rendering_resolution
186
+
187
+ # Create a batch of rays for volume rendering
188
+ # ray_origins, ray_directions = self.ray_sampler(cam2world_matrix, intrinsics, neural_rendering_resolution)
189
+
190
+ # Create triplanes by running StyleGAN backbone
191
+ # N, M, _ = ray_origins.shape
192
+ if use_cached_backbone and self._last_planes is not None:
193
+ planes = self._last_planes
194
+ else:
195
+ planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
196
+ if cache_backbone:
197
+ self._last_planes = planes
198
+
199
+ # Reshape output into three 32-channel planes
200
+ planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
201
+ return planes  # early return: the volume-rendering branch below is unreachable here
202
+
203
+ # Perform volume rendering
204
+ feature_samples, depth_samples, weights_samples = self.renderer(planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs) # channels last
205
+
206
+ # Reshape into 'raw' neural-rendered image
207
+ H = W = self.neural_rendering_resolution
208
+ feature_image = feature_samples.permute(0, 2, 1).reshape(N, feature_samples.shape[-1], H, W).contiguous()
209
+ depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
210
+
211
+ # Run superresolution to get final image
212
+ rgb_image = feature_image[:, :3]
213
+ sr_image = self.superresolution(rgb_image, feature_image, ws, noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{k:synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode'})
214
+
215
+ return {'image': sr_image, 'image_raw': rgb_image, 'image_depth': depth_image}
216
+
217
+ def sample(self, coordinates, directions, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
218
+ # Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes.
219
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
220
+ planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs)
221
+ planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
222
+ return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
223
+
224
+ def sample_mixed(self, coordinates, directions, ws, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs):
225
+ # Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z'
226
+ planes = self.backbone.synthesis(ws, update_emas = update_emas, **synthesis_kwargs)
227
+ planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1])
228
+ return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs)
229
+
230
+ def inference_mode(self):
231
+ self.inference = True
232
+
233
+ def forward(self, z, data=None, device='cuda', downsample=False, c=None, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs):
234
+ # Render a batch of generated images.
235
+ assert z.shape[-1] == self.z_dim
236
+ ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas)
237
+ planes = self.synthesis(ws, c, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, **synthesis_kwargs)
238
+ bs = planes.shape[0]
239
+ logits = []
240
+ for b in range(bs):
241
+ queries = get_co(data['vox'][b]).to(device).unsqueeze(0)
242
+ if downsample and data['vox'][b].shape[0] > 50000:
243
+ # random sample
244
+ idx = torch.randperm(data['vox'][b].shape[0])[:50000]
245
+ queries = queries[:, idx]
246
+ out = sample_from_planes(self.plane_axes.to(device), planes[b:b+1], queries)
247
+ out = self.decoder(out)
248
+ if downsample and data['vox'][b].shape[0] > 50000:
249
+ logits.append((out.squeeze(0), idx))
250
+ else:
251
+ logits.append(out.squeeze(0))
252
+ if self.inference:
253
+ assert len(logits) == 1, "Inference mode only supports batch size 1"
254
+ return extract_keypoints(logits[0], data['vox'][0])
255
+ return logits
256
+
257
+
258
+ class OSGDecoder(torch.nn.Module):
259
+ def __init__(self, n_features, options):
260
+ super().__init__()
261
+ self.hidden_dim = 64
262
+
263
+ self.net = torch.nn.Sequential(
264
+ FullyConnectedLayer(n_features, self.hidden_dim),
265
+ torch.nn.Softplus(),
266
+ FullyConnectedLayer(self.hidden_dim, 1 + options['decoder_output_dim'])
267
+ )
268
+
269
+ def forward(self, sampled_features, ray_directions=None):
270
+ # Aggregate features
271
+ sampled_features = sampled_features.mean(1)
272
+ x = sampled_features
273
+
274
+ N, M, C = x.shape
275
+ x = x.view(N*M, C)
276
+
277
+ x = self.net(x)
278
+ x = x.view(N, M, -1)
279
+ return x  # early return: the rgb/sigma branch below is unreachable
280
+ rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF
281
+ sigma = x[..., 0:1]
282
+ return {'rgb': rgb, 'sigma': sigma}
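A hedged shape check for OSGDecoder above, assuming ThirdParty.eg3d is importable (it supplies FullyConnectedLayer) and that the class is reachable as Anymate.models.joint.OSGDecoder: features sampled from the three planes are averaged, then mapped to one logit per query point since decoder_output_dim is 0 here.

import torch
from Anymate.models.joint import OSGDecoder  # assumed import path

dec = OSGDecoder(32, {'decoder_output_dim': 0})
feats = torch.randn(2, 3, 4096, 32)   # (batch, 3 planes, M query points, 32 channels)
out = dec(feats)                      # planes averaged, then the small Softplus MLP applied
print(out.shape)                      # torch.Size([2, 4096, 1])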
Anymate/models/skin.py ADDED
@@ -0,0 +1,309 @@
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from ThirdParty.michelangelo.models.modules.transformer_blocks import ResidualCrossAttentionBlock, Transformer
4
+ from ThirdParty.michelangelo.models.modules.embedder import components_from_spherical_harmonics, FourierEmbedder
5
+ from einops import repeat, rearrange
6
+
7
+ class AttendjointsDecoder_combine(nn.Module):
8
+ def __init__(self,
9
+ width = 768,
10
+ layers = 2,
11
+ device = 'cuda',
12
+ dtype = torch.float32,
13
+ heads = 12,
14
+ init_scale: float = 0.25,
15
+ flash = False,
16
+ use_checkpoint = False,
17
+ qkv_bias = False,
18
+ num_freqs: int = 8,
19
+ include_pi: bool = True,
20
+ separate = False,
21
+ use_mask = True,
22
+ use_bone = True,
23
+ inference= False):
24
+
25
+ super().__init__()
26
+ self.inference = inference
27
+ self.use_checkpoint = use_checkpoint
28
+ self.separate = separate
29
+ self.use_mask = use_mask
30
+ # self.num_latents = num_latents
31
+
32
+ # self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
33
+
34
+ self.normal_embedder = components_from_spherical_harmonics
35
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
36
+ self.bone_proj = None if not use_bone else nn.Linear(self.fourier_embedder.out_dim * 2, width, device=device, dtype=dtype)
37
+ self.use_bone = use_bone
38
+
39
+ if not self.separate:
40
+ self.co_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
41
+ self.normal_proj = nn.Linear(25, width, device=device, dtype=dtype)
42
+ else:
43
+ self.pc_proj = nn.Linear(self.fourier_embedder.out_dim + 25, width, device=device, dtype=dtype)
44
+
45
+
46
+ # self.proj_attn = nn.Linear(width, width, device=device, dtype=dtype)
47
+
48
+ self.cross_attn = nn.ModuleList([ResidualCrossAttentionBlock(
49
+ device=device,
50
+ dtype=dtype,
51
+ width=width,
52
+ heads=heads,
53
+ init_scale=init_scale,
54
+ qkv_bias=qkv_bias,
55
+ flash=flash,
56
+ ) for _ in range(layers)])
57
+
58
+ self.cross_attn_joint = nn.ModuleList([ResidualCrossAttentionBlock(
59
+ device=device,
60
+ dtype=dtype,
61
+ width=width,
62
+ heads=heads,
63
+ init_scale=init_scale,
64
+ qkv_bias=qkv_bias,
65
+ flash=flash,
66
+ ) for _ in range(layers)])
67
+
68
+ # self.joint_embed_proj = nn.ModuleList([nn.Linear(width, width, device=device, dtype=dtype) for _ in range(layers)])
69
+
70
+
71
+ self.q_proj = nn.Linear(width, width, device=device, dtype=dtype)
72
+ self.k_proj = nn.Linear(width, width, device=device, dtype=dtype)
73
+ self.ln_1 = nn.LayerNorm(width, device=device, dtype=dtype)
74
+ self.ln_2 = nn.LayerNorm(width, device=device, dtype=dtype)
75
+
76
+ # self.last_cross_attn = ResidualCrossAttentionBlock(
77
+ # device=device,
78
+ # dtype=dtype,
79
+ # width=width,
80
+ # heads=heads,
81
+ # init_scale=init_scale,
82
+ # qkv_bias=qkv_bias,
83
+ # flash=flash,
84
+ # )
85
+ # self.mlp = MLP(device=device, dtype=dtype, width=width, init_scale=init_scale)
86
+ # self.output_proj = nn.Linear(width, 1, device=device, dtype=dtype)
87
+
88
+ def forward(self, latents, data=None, device='cuda', downsample=None, dtype=torch.float32):
89
+ joints = data['bones'].to(device) if self.use_bone else data['joints'].to(device)
90
+ max_joints = max(data['bones_num']) if self.use_bone else max(data['joints_num'])
91
+ mask = data['bones_mask'].to(device) if self.use_bone else data['joints_mask']
92
+
93
+ pc = data['vertices'][..., 0:3].to(device) if self.inference else data['points_cloud'][..., 0:3].to(device)
94
+ feats = data['vertices'][..., 3:].to(device) if self.inference else data['points_cloud'][..., 3:].to(device)
95
+
96
+ if downsample and not self.inference:
97
+ # random sample
98
+ idx = torch.randperm(pc.shape[1])[:downsample].to(device)
99
+ pc = pc[:, idx]
100
+ feats = feats[:, idx]
101
+
102
+ # Embed the input data
103
+ co_embeds = self.fourier_embedder(pc)
104
+ if not self.separate:
105
+ co_embeds = self.co_proj(co_embeds)
106
+
107
+ if self.use_bone:
108
+ # joints_fourier = torch.cat((self.fourier_embedder(joints[:,:max_joints*2:2, :3]), self.fourier_embedder(joints[:,1:max_joints*2:2, :3])), dim=-1)
109
+ joints_fourier = torch.cat((self.fourier_embedder(joints[:,:max_joints,:3]), self.fourier_embedder(joints[:,:max_joints, 3:])), dim=-1)
110
+ else:
111
+ joints_fourier = self.fourier_embedder(joints[:,:max_joints, :3])
112
+
113
+ if not self.separate:
114
+ joints_embeds = self.co_proj(joints_fourier) if not self.use_bone else self.bone_proj(joints_fourier)
115
+
116
+ normal_embeds = self.normal_proj(self.normal_embedder(feats)) if not self.separate else self.normal_embedder(feats)
117
+
118
+ if not self.separate:
119
+ pc_embeds = co_embeds + normal_embeds
120
+ else:
121
+ joints_embeds = self.co_proj(joints_fourier.to(dtype)) if not self.use_bone else self.bone_proj(joints_fourier.to(dtype))
122
+ pc_embeds = self.pc_proj(torch.cat([co_embeds.to(dtype), normal_embeds.to(dtype)], dim=-1))
123
+
124
+ pc_num = pc_embeds.shape[-2]
125
+ joints_num = joints_embeds.shape[-2]
126
+ x = torch.cat([pc_embeds, joints_embeds], dim=-2)
127
+ for i, layer in enumerate(self.cross_attn):
128
+
129
+ x = layer(x, latents)
130
+ if self.use_mask:
131
+ x = self.cross_attn_joint[i](x, x[:, pc_num:], mask=mask.to(device))
132
+ else:
133
+ x = self.cross_attn_joint[i](x, x[:, pc_num:])
134
+ pc_embeds, joints_embeds = x.split([pc_num, joints_num], dim=1)
135
+
136
+ logits = torch.einsum('bnc,bmc->bnm', self.k_proj(self.ln_1(pc_embeds)), self.q_proj(self.ln_2(joints_embeds))) # (b, n, m)
137
+
138
+ if self.use_mask:
139
+ logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e8)
140
+
141
+ if downsample and not self.inference:
142
+ return logits, idx
143
+
144
+ return logits
145
+
146
+ class AttendjointsDecoder_multi(nn.Module):
147
+ def __init__(self,
148
+ # num_latents = 64,
149
+ # num_kv_latents = 257,
150
+ # out_channels = 3,
151
+ width = 768,
152
+ layers = 4,
153
+ device = 'cuda',
154
+ dtype = torch.float32,
155
+ heads = 12,
156
+ init_scale: float = 0.25,
157
+ flash = False,
158
+ use_checkpoint = False,
159
+ qkv_bias = False,
160
+ num_freqs: int = 8,
161
+ concat_num: int = 512,
162
+ include_pi: bool = True,
163
+ separate = False,
164
+ use_mask = True,
165
+ inference_with_repeat=False,
166
+ use_bone = True,
167
+ inference = False):
168
+
169
+ super().__init__()
170
+
171
+ self.use_checkpoint = use_checkpoint
172
+ self.use_mask = use_mask
173
+ self.inference_with_repeat = inference_with_repeat
174
+ self.inference = inference
175
+
176
+ self.self_attn = Transformer(
177
+ device=device,
178
+ dtype=dtype,
179
+ n_ctx=-1,
180
+ width=width,
181
+ layers=layers,
182
+ heads=heads,
183
+ init_scale=init_scale,
184
+ qkv_bias=qkv_bias,
185
+ flash=flash,
186
+ use_checkpoint=False,
187
+
188
+ )
189
+ self.concat_number = concat_num
190
+ self.separate = separate
191
+ self.normal_embedder = components_from_spherical_harmonics
192
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
193
+ self.bone_proj = None if not use_bone else nn.Linear(self.fourier_embedder.out_dim * 2, width, device=device, dtype=dtype)
194
+ self.use_bone = use_bone
195
+ if not self.separate:
196
+ self.co_proj = nn.Linear(self.fourier_embedder.out_dim, width, device=device, dtype=dtype)
197
+ self.normal_proj = nn.Linear(25, width, device=device, dtype=dtype)
198
+ else:
199
+ self.pc_proj = nn.Linear(self.fourier_embedder.out_dim + 25, width, device=device, dtype=dtype)
200
+
201
+ # self.proj_attn = nn.Linear(width, width, device=device, dtype=dtype)
202
+
203
+ # self.ln_post = nn.LayerNorm(width, device=device, dtype=dtype)
204
+ self.output_proj_joints = nn.Linear(width, width, device=device, dtype=dtype)
205
+ self.output_proj_points = nn.Linear(width, width, device=device, dtype=dtype)
206
+ self.layer_norm = nn.LayerNorm(width)
207
+
208
+ # def inference(self, latents, data=None,device='cuda', dtype='float32', use_mask=False):
209
+ def inference_mode(self):
210
+ self.inference = True
211
+
212
+ def forward(self, latents, data=None,device='cuda', downsample=None, dtype='float32'):
213
+ joints = data['bones'].to(device) if self.use_bone else data['joints'].to(device)
214
+ max_joints = max(data['bones_num']) if self.use_bone else max(data['joints_num'])
215
+
216
+ pc = data['points_cloud'][..., 0:3].to(device)
217
+ feats = data['points_cloud'][..., 3:].to(device)
218
+
219
+ if downsample:
220
+ # random sample
221
+ idx = torch.randperm(pc.shape[1])[:downsample].to(device)
222
+ pc = pc[:, idx]
223
+ feats = feats[:, idx]
224
+
225
+ bs = pc.shape[1]//self.concat_number
226
+
227
+ # Embed the input data
228
+ if self.use_bone:
229
+ # joints_fourier = torch.cat((self.fourier_embedder(joints[:,:max_joints*2:2, :3]), self.fourier_embedder(joints[:,1:max_joints*2:2, :3])), dim=-1)
230
+ joints_fourier = torch.cat((self.fourier_embedder(joints[:,:max_joints,:3]), self.fourier_embedder(joints[:,:max_joints, 3:])), dim=-1)
231
+ else:
232
+ joints_fourier = self.fourier_embedder(joints[:,:max_joints, :3])
233
+
234
+ if self.separate:
235
+ joints_embeds = self.co_proj(joints_fourier.to(dtype)) if not self.use_bone else self.bone_proj(joints_fourier.to(dtype))
236
+ points_embeds = self.fourier_embedder(pc)
237
+ normal_embeds = self.normal_embedder(feats)
238
+ points = self.pc_proj(torch.cat([points_embeds, normal_embeds], dim=-1))
239
+ else:
240
+ joints_embeds = self.co_proj(joints_fourier) if not self.use_bone else self.bone_proj(joints_fourier)
241
+ co_embeds = self.fourier_embedder(pc)
242
+ co_embeds = self.co_proj(co_embeds)
243
+ # Embed the normals
244
+ normal_embeds = self.normal_embedder(feats)
245
+ normal_embeds = self.normal_proj(normal_embeds) # (b, n, c)
246
+ points = (co_embeds + normal_embeds)
247
+
248
+ repeated_latents = repeat(latents, "b m c -> b n m c", n=bs)
249
+ repeated_joints = repeat(joints_embeds, "b m c -> b n m c", n=bs)
250
+ points = points.reshape( latents.shape[0], bs, self.concat_number, -1)
251
+
252
+ # Concatenate embeddings
253
+ x = torch.cat([repeated_joints, points, repeated_latents], dim=-2) # (b, bs, concat_number+latent_num+joints_num, c)
254
+
255
+ # Pass through self-attention
256
+ if self.use_mask:
257
+ mask = data['bones_mask'].to(device)
258
+ append_size = x.shape[2]-mask.shape[1] # the zero needs to append after mask
259
+ batch_size = mask.shape[0]
260
+ mask_extend = torch.ones((batch_size,append_size)).to(device)
261
+ mask = torch.cat([mask,mask_extend],dim=-1).repeat(bs,1).to(device)
262
+ x = rearrange(x, "b n m c -> (b n) m c")
263
+ x = self.self_attn(x,mask)
264
+ else:
265
+ x = rearrange(x, "b n m c -> (b n) m c")
266
+ x = self.self_attn(x)
267
+ joints, points, _ = x.split([joints_embeds.shape[1],self.concat_number, latents.shape[1]], dim=1)
268
+ joints = self.output_proj_joints(self.layer_norm(joints))
269
+ points = self.output_proj_points(self.layer_norm(points))
270
+
271
+ logits = torch.einsum('bik,bjk->bij', points, joints)
272
+ logits = rearrange(logits, '(b n) m c -> b (n m) c', b=pc.shape[0],n=bs) # (b, n, c)
273
+
274
+ if self.use_mask:
275
+ mask = data['bones_mask'].to(device)
276
+ logits = logits.masked_fill(mask.unsqueeze(1) == 0, -1e8)
277
+
278
+ if self.inference:
279
+ vertices = data['vertice']
280
+ points_cloud = data['points_cloud'][0,..., 0:3].to(device)
281
+ vertices_exp = vertices[0,...,:3] # (batch_size, num_vertices, 1, 3)
282
+ logits = compute_nearest_points(vertices_exp, points_cloud, logits[0], device)
283
+
284
+ if downsample:
285
+ return logits, idx
286
+
287
+ return logits
288
+
289
+ def compute_nearest_points(vertices, points, logits, device, batch_size=1024):
290
+ # vertices: [N, 3]
291
+ # points: [M, 3]
292
+ # logits: [M, K] (K is the number of skinning weights)
293
+
294
+ num_vertices = vertices.shape[0]
295
+ # Initialize the output tensor for skinning weights
296
+ skin_predict = torch.zeros((num_vertices, logits.shape[1]), device=device)
297
+
298
+ # Split vertices into batches
299
+ for i in range(0, num_vertices, batch_size):
300
+
301
+ batch_vertices = vertices[i:i+batch_size] # [batch_size, 3]
302
+ vertices_exp = batch_vertices.unsqueeze(1) # [batch_size, 1, 3]
303
+ points_exp = points.unsqueeze(0) # [1, num_points, 3]
304
+ distances = torch.sum((vertices_exp - points_exp) ** 2, dim=-1) # [batch_size, num_points]
305
+ nearest_idx = torch.argmin(distances, dim=-1) # [batch_size]
306
+ skin_predict_batch = logits[nearest_idx] # [batch_size, K]
307
+ skin_predict[i:i+batch_size] = skin_predict_batch
308
+
309
+ return skin_predict
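The skinning decoders above emit per-point logits over joints with padded joints masked to -1e8; a softmax over the joint axis turns them into normalized skin weights, and compute_nearest_points() transfers those weights from sampled points to mesh vertices. A synthetic sketch of both steps (nearest neighbours done with torch.cdist here for brevity):

import torch

logits = torch.randn(6, 4)                     # 6 sampled points, 4 joints
logits[:, 3] = -1e8                            # pretend the last joint is padding
weights = torch.softmax(logits, dim=-1)        # rows sum to 1, padded joint gets ~0 weight

points = torch.rand(6, 3)                      # sampled point positions
vertices = torch.rand(10, 3)                   # mesh vertex positions
nearest = torch.cdist(vertices, points).argmin(dim=-1)
vertex_weights = weights[nearest]              # (10, 4) skin weights copied per vertex
print(vertex_weights.sum(dim=-1))              # ~1.0 everywhere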
Anymate/tmp/.gitkeep ADDED
File without changes
Anymate/utils/dataset_utils.py ADDED
@@ -0,0 +1,129 @@
 
 
1
+ import numpy as np
2
+ import torch
3
+ import trimesh
4
+ from ThirdParty.Rignet_utils import binvox_rw
5
+
6
+
7
+ def sparse_to_index(sparse_matrix):
8
+ index = []
9
+ weight = []
10
+ for j in range(len(sparse_matrix)):
11
+ if sparse_matrix[j] > 0:
12
+ index.append(j)
13
+ weight.append(sparse_matrix[j])
14
+
15
+ return index, weight
16
+
17
+ def index_to_sparse(index, weight, shape):
18
+ sparse_matrix = np.zeros([shape[0], shape[1], shape[2]+1])
19
+
20
+ row_indices, col_indices = np.meshgrid(np.arange(sparse_matrix.shape[0]), np.arange(sparse_matrix.shape[1]), indexing='ij')
21
+
22
+ row_indices = np.expand_dims(row_indices, axis=-1)
23
+ col_indices = np.expand_dims(col_indices, axis=-1)
24
+
25
+ sparse_matrix[row_indices, col_indices, index] = weight
26
+
27
+
28
+ return torch.from_numpy(sparse_matrix[:, :, :-1])
29
+
30
+ def index_to_sparse_con(index, shape):
31
+
32
+ sparse_matrix = np.zeros([shape[0], shape[1], shape[2]+1],dtype=np.int8)
33
+ row_indices, col_indices = np.meshgrid(np.arange(sparse_matrix.shape[0]), np.arange(sparse_matrix.shape[1]), indexing='ij')
34
+
35
+ row_indices = np.expand_dims(row_indices, axis=-1)
36
+ col_indices = np.expand_dims(col_indices, axis=-1)
37
+
38
+ sparse_matrix[row_indices, col_indices, index] = 1
39
+
40
+
41
+ return torch.from_numpy(sparse_matrix[:, :, :-1])
42
+
43
+ def create_mask(n, max_len=64):
44
+ mask = torch.zeros(max_len, dtype=torch.bool)
45
+ mask[:n] = 1
46
+ return mask
47
+
48
+ def reduce(vox):
49
+ new_data = np.zeros((vox.dims[0] // 2, vox.dims[1] // 2, vox.dims[2] // 2)).astype(bool)
50
+ new_data = np.logical_or(new_data, vox.data[::2, ::2, ::2])
51
+ new_data = np.logical_or(new_data, vox.data[1::2, ::2, ::2])
52
+ new_data = np.logical_or(new_data, vox.data[::2, 1::2, ::2])
53
+ new_data = np.logical_or(new_data, vox.data[::2, ::2, 1::2])
54
+ new_data = np.logical_or(new_data, vox.data[1::2, 1::2, ::2])
55
+ new_data = np.logical_or(new_data, vox.data[1::2, ::2, 1::2])
56
+ new_data = np.logical_or(new_data, vox.data[::2, 1::2, 1::2])
57
+ new_data = np.logical_or(new_data, vox.data[1::2, 1::2, 1::2])
58
+ # dilate the new voxel
59
+ new_data[:-1, :, :] = np.logical_or(new_data[:-1, :, :], new_data[1:, :, :])
60
+ new_data[:, :-1, :] = np.logical_or(new_data[:, :-1, :], new_data[:, 1:, :])
61
+ new_data[:, :, :-1] = np.logical_or(new_data[:, :, :-1], new_data[:, :, 1:])
62
+ return binvox_rw.Voxels(new_data, new_data.shape, vox.translate, vox.scale, vox.axis_order)
63
+
64
+ def align(vox, y_max):
65
+ new_data = np.zeros(vox.dims).astype(bool)
66
+ ind = np.argwhere(vox.data)
67
+ ind = ind + (np.array(vox.translate) - np.array([-0.5, -0.5 * (1 - y_max), -0.5])) * vox.dims[0]
68
+ # snap to integer voxel indices (ceil is used; rounding kept below for reference)
69
+ # ind = np.round(ind).astype(int)
70
+ ind = np.ceil(ind).astype(int)
71
+ # clip to the valid range
72
+ ind = np.clip(ind, 0, vox.dims[0] - 1)
73
+ # new_data[ind[:, 0], ind[:, 1], ind[:, 2]] = True
74
+ return ind
75
+
76
+ def get_skin_direction(joint_idx, data, parent_index, joints_matrix):
77
+ # Get points influenced by this joint (weight > 0)
78
+ weights = index_to_sparse(data['skins_index'].unsqueeze(0), data['skins_weight'].unsqueeze(0), [1, 8192, data['bones_num']])[0][:,joint_idx]
79
+ mask = weights > 0
80
+
81
+ if not torch.any(mask):
82
+ # If no points are influenced, return the opposite direction of its parent
83
+ parent_idx = parent_index[joint_idx].item()
84
+ if parent_idx == joint_idx:
85
+ return torch.tensor([0, 0, 0.001])
86
+ parent_pos = joints_matrix[parent_idx, :3]
87
+ joint_pos = joints_matrix[joint_idx, :3]
88
+ direction = joint_pos - parent_pos
89
+ norm = torch.norm(direction)
90
+ if norm < 1e-8: # Add check for zero norm
91
+ return torch.tensor([0, 0, 0.001])
92
+ normalized_direction = direction / norm
93
+ return normalized_direction * 0.01
94
+
95
+ # Get joint position
96
+ joint_pos = joints_matrix[joint_idx, :3]
97
+
98
+ # Get weighted average direction from joint to influenced points
99
+ points = data['pc'][mask][:,:3]
100
+ point_weights = weights[mask]
101
+
102
+ # Calculate directions from joint to each point
103
+ directions = points - joint_pos
104
+
105
+ # Calculate weighted average direction
106
+ avg_direction = torch.sum(directions * point_weights.unsqueeze(1), dim=0) / torch.sum(point_weights)
107
+ if torch.norm(avg_direction) < 1e-5:
108
+ return torch.tensor([0, 0, 0.001])
109
+ return avg_direction * 1.25
110
+
111
+ def obj2mesh(obj_path):
112
+ # open the obj as txt
113
+ vertices = []
114
+ faces = []
115
+ with open(obj_path, 'r') as f:
116
+ obj = f.readlines()
117
+ for line in obj:
118
+ if line.startswith('v '):
119
+ vertices.append(list(map(float, line.split()[1:])))
120
+ elif line.startswith('f '):
121
+ faces.append(list(map(int, [i.split('/')[0] for i in line.split()[1:]])))
122
+ vertices = np.array(vertices)
123
+ faces = np.array(faces) - 1
124
+ # print(vertices.shape, faces.shape)
125
+
126
+ # create trimesh mesh with given vertices and faces
127
+ mesh = trimesh.Trimesh(vertices, faces, process=False)
128
+ # print(mesh.vertices.shape, mesh.faces.shape)
129
+ return mesh
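
A small round-trip sketch of the padded-index convention that `index_to_sparse` and `sparse_to_index` above appear to assume (unused slots point at the extra dummy column, which is dropped); the toy sizes are hypothetical and the two helpers are assumed importable from this module:

import numpy as np

# 1 mesh, 4 vertices, 3 bones, at most 2 influences per vertex;
# index 3 (== num_bones) is the dummy slot that index_to_sparse discards.
num_bones = 3
index = np.array([[[0, 1], [2, 3], [1, 3], [0, 2]]])                    # (1, 4, 2)
weight = np.array([[[0.7, 0.3], [1.0, 0.0], [1.0, 0.0], [0.5, 0.5]]])   # (1, 4, 2)

dense = index_to_sparse(index, weight, [1, 4, num_bones])               # (1, 4, 3) torch tensor
print(dense[0])                                  # dense per-vertex skinning weights

idx, w = sparse_to_index(dense[0, 0].numpy())    # recover vertex 0
print(idx, w)                                    # indices [0, 1], weights [0.7, 0.3]
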
Anymate/utils/diffusion_encoder.py ADDED
@@ -0,0 +1,258 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from typing import Optional
4
+ from einops import repeat
5
+ import math
6
+ from ThirdParty.michelangelo.models.modules.transformer_blocks import ResidualCrossAttentionBlock,Transformer, checkpoint
7
+ from torch.nn import Sequential, Dropout, Linear, ReLU, Parameter, BatchNorm1d
8
+ from typing import List, Optional, Tuple, Union
9
+
10
+ class ShapeAsLatentModule(nn.Module):
11
+ latent_shape: Tuple[int, int]
12
+
13
+ def __init__(self, *args, **kwargs):
14
+ super().__init__()
15
+
16
+ def encode(self, *args, **kwargs):
17
+ raise NotImplementedError
18
+
19
+ def decode(self, *args, **kwargs):
20
+ raise NotImplementedError
21
+
22
+ def query_geometry(self, *args, **kwargs):
23
+ raise NotImplementedError
24
+
25
+ class FourierEmbedder(nn.Module):
26
+
27
+ def __init__(self,
28
+ num_freqs: int = 6,
29
+ logspace: bool = True,
30
+ input_dim: int = 3,
31
+ include_input: bool = True,
32
+ include_pi: bool = True) -> None:
33
+
34
+ """The initialization"""
35
+
36
+ super().__init__()
37
+
38
+ if logspace:
39
+ frequencies = 2.0 ** torch.arange(
40
+ num_freqs,
41
+ dtype=torch.float32
42
+ )
43
+ else:
44
+ frequencies = torch.linspace(
45
+ 1.0,
46
+ 2.0 ** (num_freqs - 1),
47
+ num_freqs,
48
+ dtype=torch.float32
49
+ )
50
+
51
+ if include_pi:
52
+ frequencies *= torch.pi
53
+
54
+ self.register_buffer("frequencies", frequencies, persistent=False)
55
+ self.include_input = include_input
56
+ self.num_freqs = num_freqs
57
+
58
+ self.out_dim = self.get_dims(input_dim)
59
+
60
+ def get_dims(self, input_dim):
61
+ temp = 1 if self.include_input or self.num_freqs == 0 else 0
62
+ out_dim = input_dim * (self.num_freqs * 2 + temp)
63
+
64
+ return out_dim
65
+
66
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
67
+
68
+ if self.num_freqs > 0:
69
+ self.frequencies = self.frequencies.to(x.device)
70
+ embed = (x[..., None].contiguous() * self.frequencies).view(*x.shape[:-1], -1)
71
+
72
+ if self.include_input:
73
+ return torch.cat((x, embed.sin(), embed.cos()), dim=-1)
74
+ else:
75
+ return torch.cat((embed.sin(), embed.cos()), dim=-1)
76
+ else:
77
+ return x
78
+
79
+ def MLP(channels, batch_norm=True):
80
+ if batch_norm:
81
+ return Sequential(*[Sequential(Linear(channels[i - 1], channels[i]), ReLU(), BatchNorm1d(channels[i], momentum=0.1))
82
+ for i in range(1, len(channels))])
83
+ else:
84
+ return Sequential(*[Sequential(Linear(channels[i - 1], channels[i]), ReLU()) for i in range(1, len(channels))])
85
+
86
+ class CrossAttentionEncoder(nn.Module):
87
+
88
+ def __init__(self, *,
89
+ device: Optional[torch.device],
90
+ dtype: Optional[torch.dtype],
91
+ num_latents: int,
92
+ fourier_embedder: FourierEmbedder,
93
+ point_feats: int,
94
+ width: int,
95
+ heads: int,
96
+ layers: int,
97
+ init_scale: float = 0.25,
98
+ qkv_bias: bool = True,
99
+ flash: bool = False,
100
+ use_ln_post: bool = False,
101
+ use_checkpoint: bool = False):
102
+
103
+ super().__init__()
104
+
105
+ self.use_checkpoint = use_checkpoint
106
+ self.num_latents = num_latents
107
+ self.query = nn.Parameter(torch.randn((num_latents, width), device=device, dtype=dtype) * 0.02)
108
+
109
+ self.fourier_embedder = fourier_embedder
110
+ self.input_proj = nn.Linear(self.fourier_embedder.out_dim + point_feats, width, device=device, dtype=dtype)
111
+ self.cross_attn = ResidualCrossAttentionBlock(
112
+ device=device,
113
+ dtype=dtype,
114
+ width=width,
115
+ heads=heads,
116
+ init_scale=init_scale,
117
+ qkv_bias=qkv_bias,
118
+ flash=flash,
119
+ )
120
+
121
+ self.self_attn = Transformer(
122
+ device=device,
123
+ dtype=dtype,
124
+ n_ctx=num_latents,
125
+ width=width,
126
+ layers=layers,
127
+ heads=heads,
128
+ init_scale=init_scale,
129
+ qkv_bias=qkv_bias,
130
+ flash=flash,
131
+ use_checkpoint=False
132
+ )
133
+
134
+ if use_ln_post:
135
+ self.ln_post = nn.LayerNorm(width, dtype=dtype, device=device)
136
+ else:
137
+ self.ln_post = None
138
+
139
+ def _forward(self, pc, feats):
140
+ """
141
+
142
+ Args:
143
+ pc (torch.FloatTensor): [B, N, 3]
144
+ feats (torch.FloatTensor or None): [B, N, C]
145
+
146
+ Returns:
147
+
148
+ """
149
+
150
+ bs = pc.shape[0]
151
+
152
+ data = self.fourier_embedder(pc)
153
+ if feats is not None:
154
+ data = torch.cat([data, feats], dim=-1)
155
+ data = self.input_proj(data)
156
+
157
+ query = repeat(self.query, "m c -> b m c", b=bs)
158
+ latents = self.cross_attn(query, data)
159
+ latents = self.self_attn(latents)
160
+
161
+ if self.ln_post is not None:
162
+ latents = self.ln_post(latents)
163
+
164
+ return latents, pc
165
+
166
+ def forward(self, pc: torch.FloatTensor, feats: Optional[torch.FloatTensor] = None):
167
+ """
168
+
169
+ Args:
170
+ pc (torch.FloatTensor): [B, N, 3]
171
+ feats (torch.FloatTensor or None): [B, N, C]
172
+
173
+ Returns:
174
+ dict
175
+ """
176
+
177
+ return checkpoint(self._forward, (pc, feats), self.parameters(), self.use_checkpoint)
178
+
179
+
180
+
181
+ class TransformerEncoder(ShapeAsLatentModule):
182
+ def __init__(self, *,
183
+ device: Optional[torch.device]='cuda',
184
+ dtype: Optional[torch.dtype],
185
+ num_latents: int = 16,
186
+ point_feats: int = 3,
187
+ embed_dim: int = 64,
188
+ num_freqs: int = 8,
189
+ include_pi: bool = True,
190
+ width: int = 768,
191
+ heads: int = 12,
192
+ num_encoder_layers: int = 8,
193
+ init_scale: float = 0.25,
194
+ qkv_bias: bool = True,
195
+ flash: bool = False,
196
+ use_ln_post: bool = False,
197
+ use_checkpoint: bool = False,
198
+ out_channels: int = 4):
199
+
200
+ super().__init__()
201
+
202
+ self.use_checkpoint = use_checkpoint
203
+
204
+ self.num_latents = num_latents
205
+ self.fourier_embedder = FourierEmbedder(num_freqs=num_freqs, include_pi=include_pi)
206
+
207
+ init_scale = init_scale * math.sqrt(1.0 / width)
208
+ self.encoder = CrossAttentionEncoder(
209
+ device=device,
210
+ dtype=dtype,
211
+ fourier_embedder=self.fourier_embedder,
212
+ num_latents=num_latents,
213
+ point_feats=point_feats,
214
+ width=width,
215
+ heads=heads,
216
+ layers=num_encoder_layers,
217
+ init_scale=init_scale,
218
+ qkv_bias=qkv_bias,
219
+ flash=flash,
220
+ use_ln_post=use_ln_post,
221
+ use_checkpoint=use_checkpoint
222
+ )
223
+ self.width = width
224
+ self.out_channels = out_channels
225
+ self.device = device
226
+
227
+ self.embed_dim = embed_dim
228
+
229
+ def encode(self,data):
230
+ input_points = data['points_cloud'].to(self.device)
231
+ bs = input_points.shape[0]
232
+ pc, feats = input_points[...,:3], input_points[..., 3:]
233
+ latents, _ = self.encoder(pc, feats)
234
+ # print_time('after encoder')
235
+ latents = latents.reshape(bs,-1, self.width)
236
+ return latents
237
+ def encode_pc(self,points_cloud):
238
+ bs = points_cloud.shape[0]
239
+ input_points = points_cloud.to(self.device)
240
+ pc, feats = input_points[...,:3], input_points[..., 3:]
241
+ latents, _ = self.encoder(pc, feats)
242
+
243
+ latents = latents.reshape(bs,-1, self.width)
244
+ return latents
245
+ def forward(self, data):
246
+
247
+ # input_points = torch.from_numpy(np.array(data.points_cloud)).cuda()
248
+ input_points = data['points_cloud'].to(self.device)
249
+ pc, feats = input_points[...,:3], input_points[..., 3:]
250
+ latents, _ = self.encoder(pc, feats)
251
+
252
+ latents = latents.reshape(-1, self.width)
253
+ latents = latents.reshape(-1, self.num_latents, self.out_channels)
254
+ latents[..., :3] = torch.tanh(latents[..., :3])
255
+ latents[..., 3:] = torch.sigmoid(latents[..., 3:])
256
+
257
+
258
+ return latents
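
A quick shape check for the `FourierEmbedder` above, a sketch assuming only torch; with the defaults used here (3-D input, 8 frequencies, input included) the output width is 3 * (2 * 8 + 1) = 51:

import torch

emb = FourierEmbedder(num_freqs=8, include_pi=True)   # assumed importable from this module
pts = torch.rand(2, 1024, 3)                          # hypothetical batch of point clouds
print(emb.out_dim)                                    # 51
print(emb(pts).shape)                                 # torch.Size([2, 1024, 51])
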
Anymate/utils/diffusion_utils.py ADDED
@@ -0,0 +1,314 @@
1
+
2
+ import numpy as np
3
+ import matplotlib.pyplot as plt
4
+ from mpl_toolkits.mplot3d import Axes3D
5
+ from torchvision.utils import make_grid
6
+ import torch
7
+ from typing import List, Optional, Tuple, Union
8
+ import torch.nn as nn
9
+ import math
10
+ from timm.models.vision_transformer import Mlp, DropPath
11
+
12
+ def my_collate_diff(batch,return_joints_num=128,random=False):
13
+ data = {}
14
+ for key in batch[0]:
15
+ if key in ('vox', 'name', 'joints_num', 'skins_index', 'skins_weight', 'parent_index', 'conns', 'joints', 'bones', 'mesh_skins_index', 'mesh_skins_weight', 'mesh_pc', 'mesh_face'):
16
+ data[key] = [sample[key] for sample in batch]
17
+ elif key=='pc':
18
+ data['points_cloud'] = torch.stack([sample['pc'] for sample in batch])
19
+ elif key=='skins':
20
+ continue
21
+ elif key=='bones_num':
22
+ data[key] = torch.tensor([sample['bones_num'] for sample in batch])
23
+ else:
24
+ data[key] = torch.stack([sample[key] for sample in batch])
25
+
26
+ if 'joints' in batch[0]:
27
+ padded_joints_matrix = torch.ones(len(data['name']), return_joints_num, 3) * (-3)
28
+ joints_matrix = torch.ones(len(data['name']), 96, 3) * (-3)
29
+ for i in range(len(data['name'])):
30
+ joints_matrix[i, :data['joints_num'][i], :] = data['joints'][i]
31
+ if not random:
32
+ for i in range(len(data['name'])):
33
+ padded_joints_matrix[i] = data['joints'][i].repeat(return_joints_num//data['joints_num'][i]+1,1)[:return_joints_num,:]
34
+ else:
35
+ for i in range(len(data['name'])):
36
+ padded_joints_matrix[i] = data['joints'][i][torch.randint(0, data['joints_num'][i], (return_joints_num,))]
37
+ data['joints_repeat'] = padded_joints_matrix
38
+ data['joints'] = joints_matrix
39
+
40
+ return data
41
+
42
+ def randn_tensor(
43
+ shape: Union[Tuple, List],
44
+ generator: Optional[Union[List["torch.Generator"], "torch.Generator"]] = None,
45
+ device: Optional["torch.device"] = None,
46
+ dtype: Optional["torch.dtype"] = None,
47
+ layout: Optional["torch.layout"] = None,
48
+ ):
49
+ """A helper function to create random tensors on the desired `device` with the desired `dtype`. When
50
+ passing a list of generators, you can seed each batch element individually. If CPU generators are passed, the tensor
51
+ is always created on the CPU.
52
+ """
53
+ # device on which tensor is created defaults to device
54
+ rand_device = device
55
+ batch_size = shape[0]
56
+
57
+ layout = layout or torch.strided
58
+ device = device or torch.device("cpu")
59
+
60
+ if generator is not None:
61
+ gen_device_type = generator.device.type if not isinstance(generator, list) else generator[0].device.type
62
+ if gen_device_type != device.type and gen_device_type == "cpu":
63
+ rand_device = "cpu"
64
+
65
+ elif gen_device_type != device.type and gen_device_type == "cuda":
66
+ raise ValueError(f"Cannot generate a {device} tensor from a generator of type {gen_device_type}.")
67
+
68
+ # make sure generator list of length 1 is treated like a non-list
69
+ if isinstance(generator, list) and len(generator) == 1:
70
+ generator = generator[0]
71
+
72
+ if isinstance(generator, list):
73
+ shape = (1,) + shape[1:]
74
+ latents = [
75
+ torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype, layout=layout)
76
+ for i in range(batch_size)
77
+ ]
78
+ latents = torch.cat(latents, dim=0).to(device)
79
+ else:
80
+ latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype, layout=layout).to(device)
81
+
82
+ return latents
83
+
84
+ def timestep_embedding(timesteps, dim, max_period=10000):
85
+ """
86
+ Create sinusoidal timestep embeddings.
87
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
88
+ These may be fractional.
89
+ :param dim: the dimension of the output.
90
+ :param max_period: controls the minimum frequency of the embeddings.
91
+ :return: an [N x dim] Tensor of positional embeddings.
92
+ """
93
+ half = dim // 2
94
+ freqs = torch.exp(
95
+ -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half
96
+ ).to(device=timesteps.device)
97
+ args = timesteps[:, None].to(timesteps.dtype) * freqs[None]
98
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
99
+ if dim % 2:
100
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
101
+ return embedding
102
+
103
+ class CrossAttention(nn.Module):
104
+ def __init__(
105
+ self,
106
+ dim,
107
+ kv_dim=None,
108
+ num_heads=16,
109
+ qkv_bias=False,
110
+ attn_drop=0.,
111
+ proj_drop=0.,
112
+ ):
113
+ super().__init__()
114
+ self.num_heads = num_heads
115
+ head_dim = dim // num_heads
116
+ self.scale = head_dim ** -0.5
117
+
118
+ kv_dim = dim if not kv_dim else kv_dim
119
+ self.wq = nn.Linear(dim, dim, bias=qkv_bias)
120
+ self.wk = nn.Linear(kv_dim, dim, bias=qkv_bias)
121
+ self.wv = nn.Linear(kv_dim, dim, bias=qkv_bias)
122
+ self.attn_drop_rate = attn_drop
123
+ self.attn_drop = nn.Dropout(self.attn_drop_rate)
124
+ self.proj = nn.Linear(dim, dim)
125
+ self.proj_drop = nn.Dropout(proj_drop)
126
+
127
+ def forward(self, x_q, x_kv):
128
+ B, N_q, C = x_q.shape
129
+ B, N_kv, _ = x_kv.shape
130
+ # [B, N_q, C] -> [B, N_q, H, C/H] -> [B, H, N_q, C/H]
131
+ q = self.wq(x_q).reshape(B, N_q, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
132
+ # [B, N_kv, C] -> [B, N_kv, H, C/H] -> [B, H, N_kv, C/H]
133
+ k = self.wk(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
134
+ # [B, N_kv, C] -> [B, N_kv, H, C/H] -> [B, H, N_kv, C/H]
135
+ v = self.wv(x_kv).reshape(B, N_kv, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
136
+
137
+ # [B, H, N_q, C/H] @ [B, H, C/H, N_kv] -> [B, H, N_q, N_kv]
138
+ attn = (q @ k.transpose(-2, -1)) * self.scale
139
+ attn = attn.softmax(dim=-1)
140
+ attn = self.attn_drop(attn)
141
+
142
+ # [B, H, N_q, N_kv] @ [B, H, N_kv, C/H] -> [B, H, N_q, C/H]
143
+ x = attn @ v
144
+
145
+ # [B, H, N_q, C/H] -> [B, N_q, C]
146
+ x = x.transpose(1, 2).reshape(B, N_q, C)
147
+ x = self.proj(x)
148
+ x = self.proj_drop(x)
149
+ return x
150
+
151
+ class Compute_Block(nn.Module):
152
+
153
+ def __init__(self, z_dim, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
154
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
155
+ super().__init__()
156
+ self.norm_z1 = norm_layer(z_dim)
157
+ self.attn = CrossAttention(
158
+ z_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
159
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
160
+ self.norm_z2 = norm_layer(z_dim)
161
+ mlp_hidden_dim = int(z_dim * mlp_ratio)
162
+ self.mlp = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
163
+
164
+ def forward(self, z):
165
+ zn = self.norm_z1(z)
166
+ z = z + self.drop_path(self.attn(zn, zn))
167
+ z = z + self.drop_path(self.mlp(self.norm_z2(z)))
168
+ return z
169
+
170
+ class Read_Block(nn.Module):
171
+
172
+ def __init__(self, z_dim, x_dim, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
173
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
174
+ super().__init__()
175
+ self.norm_x = norm_layer(x_dim)
176
+ self.norm_z1 = norm_layer(z_dim)
177
+ self.attn = CrossAttention(
178
+ z_dim, x_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
179
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
180
+ self.norm_z2 = norm_layer(z_dim)
181
+ mlp_hidden_dim = int(z_dim * mlp_ratio)
182
+ self.mlp = Mlp(in_features=z_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
183
+
184
+ def forward(self, z, x):
185
+ z = z + self.drop_path(self.attn(self.norm_z1(z), self.norm_x(x)))
186
+ z = z + self.drop_path(self.mlp(self.norm_z2(z)))
187
+ return z
188
+
189
+ class Write_Block(nn.Module):
190
+
191
+ def __init__(self, z_dim, x_dim, num_heads=16, mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
192
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
193
+ super().__init__()
194
+ self.norm_z = norm_layer(z_dim)
195
+ self.norm_x1 = norm_layer(x_dim)
196
+ self.attn = CrossAttention(
197
+ x_dim, z_dim, num_heads=num_heads, qkv_bias=qkv_bias, attn_drop=attn_drop, proj_drop=drop)
198
+ self.drop_path = DropPath(drop_path) if drop_path > 0. else nn.Identity()
199
+ self.norm_x2 = norm_layer(x_dim)
200
+ mlp_hidden_dim = int(x_dim * mlp_ratio)
201
+ self.mlp = Mlp(in_features=x_dim, hidden_features=mlp_hidden_dim, act_layer=act_layer, drop=drop)
202
+
203
+ def forward(self, z, x):
204
+ x = x + self.drop_path(self.attn(self.norm_x1(x), self.norm_z(z)))
205
+ x = x + self.drop_path(self.mlp(self.norm_x2(x)))
206
+ return x
207
+
208
+ class RCW_Block(nn.Module):
209
+
210
+ def __init__(self, z_dim, x_dim, num_compute_layers=4, num_heads=16,
211
+ mlp_ratio=4., qkv_bias=False, drop=0., attn_drop=0.,
212
+ drop_path=0., act_layer=nn.GELU, norm_layer=nn.LayerNorm):
213
+ super().__init__()
214
+ self.read = Read_Block(z_dim, x_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop,
215
+ attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer)
216
+ self.write = Write_Block(z_dim, x_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop,
217
+ attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer)
218
+ self.compute = nn.ModuleList([
219
+ Compute_Block(z_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop,
220
+ attn_drop=attn_drop, drop_path=drop_path, act_layer=act_layer, norm_layer=norm_layer)
221
+ for _ in range(num_compute_layers)
222
+ ])
223
+
224
+ def forward(self, z, x):
225
+ z = self.read(z, x)
226
+ for layer in self.compute:
227
+ z = layer(z)
228
+ x = self.write(z, x)
229
+ return z, x
230
+
231
+ def pairwise_distances(x, y):
232
+ # Input: x is an Nxd matrix
233
+ # y is an Mxd matrix
234
+ # Output: dist is an NxM matrix where dist[i,j] is the squared norm between x[i,:] and y[j,:],
235
+ # i.e. dist[i,j] = ||x[i,:]-y[j,:]||^2
236
+ # (values are clamped to [0, inf) before returning)
237
+ x_norm = (x ** 2).sum(1).view(-1, 1)
238
+ y_t = torch.transpose(y, 0, 1)
239
+ y_norm = (y ** 2).sum(1).view(1, -1)
240
+ dist = x_norm + y_norm - 2.0 * torch.mm(x, y_t)
241
+ return torch.clamp(dist, 0.0, np.inf)
242
+
243
+ def meanshift_cluster(pts_in, bandwidth, weights=None, max_iter=20):
244
+ """
245
+ Meanshift clustering
246
+ :param pts_in: input points
247
+ :param bandwidth: bandwidth
248
+ :param weights: weights per pts indicting its importance in the clustering
249
+ :return: points after clustering
250
+ """
251
+ diff = 1e10
252
+ num_iter = 1
253
+ while diff > 1e-3 and num_iter < max_iter:
254
+ Y = np.sum(((pts_in[np.newaxis, ...] - pts_in[:, np.newaxis, :]) ** 2), axis=2)
255
+ K = np.maximum(bandwidth**2 - Y, np.zeros(Y.shape))
256
+ if weights is not None:
257
+ K = K * weights
258
+ row_sums = K.sum(axis=0, keepdims=True)
259
+ P = K / (row_sums + 1e-10)
260
+ P = P.transpose()
261
+ pts_in_prim = 0.3 * (np.matmul(P, pts_in) - pts_in) + pts_in
262
+ diff = np.sqrt(np.sum((pts_in_prim - pts_in)**2))
263
+ pts_in = pts_in_prim
264
+ num_iter += 1
265
+ return pts_in
266
+
267
+ def nms_meanshift(pts_in, density, bandwidth):
268
+ """
269
+ NMS to extract modes after meanshift. Code refers to scikit-learn.
270
+ :param pts_in: input points
271
+ :param density: density at each point
272
+ :param bandwidth: bandwidth used in meanshift. Used here as neighbor region for NMS
273
+ :return: extracted clusters.
274
+ """
275
+ Y = np.sum(((pts_in[np.newaxis, ...] - pts_in[:, np.newaxis, :]) ** 2), axis=2)
276
+ sorted_ids = np.argsort(density)[::-1]
277
+ unique = np.ones(len(sorted_ids), dtype=bool)
278
+ dist = np.sqrt(Y)
279
+ for i in sorted_ids:
280
+ if unique[i]:
281
+ neighbor_idxs = np.argwhere(dist[:, i] <= bandwidth)
282
+ unique[neighbor_idxs.squeeze()] = 0
283
+ unique[i] = 1 # leave the current point as unique
284
+ pts_in = pts_in[unique]
285
+ return pts_in
286
+
287
+ def get_predictions(y_pred_np, attn_pred_np=None,bandwidth=0.05, threshold=0.001):
288
+ """
289
+ get the final predictions
290
+ :param y_pred_np: predicted joint positions to cluster
291
+ :param attn_pred_np: per-point weights used during clustering
292
+ :return: clustered joint positions
293
+ """
294
+ # if attn_pred_np is None:
295
+ # attn_pred_np = np.ones(y_pred_np.shape[0])
296
+ y_pred_np = meanshift_cluster(y_pred_np, bandwidth, attn_pred_np, max_iter=40)
297
+
298
+
299
+ Y_dist = np.sum(((y_pred_np[np.newaxis, ...] - y_pred_np[:, np.newaxis, :]) ** 2), axis=2)
300
+ density = np.maximum(bandwidth ** 2 - Y_dist, np.zeros(Y_dist.shape))
301
+ density = np.sum(density, axis=0)
302
+ density_sum = np.sum(density)
303
+ y_pred_np = y_pred_np[density / density_sum > threshold]
304
+
305
+ density = density[density / density_sum > threshold]
306
+ pred_joints = nms_meanshift(y_pred_np, density, bandwidth)
307
+ return pred_joints
308
+
309
+
310
+ if __name__ == '__main__':
311
+ points_cloud = np.ones((100, 3))
312
+ predict_out = get_predictions(points_cloud, bandwidth=0.05, threshold=0.001)
313
+ print(predict_out.shape)
314
+
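
Two small shape checks for the building blocks above, a sketch that assumes torch is available and that the timm version in use exposes Mlp and DropPath as imported at the top of this file; all dimensions are hypothetical:

import torch

# sinusoidal timestep embedding: one 128-d vector per timestep
t = torch.tensor([0, 10, 999])
print(timestep_embedding(t, dim=128).shape)        # torch.Size([3, 128])

# read-compute-write block: latent tokens z attend to data tokens x and back
block = RCW_Block(z_dim=256, x_dim=256, num_compute_layers=2)
z = torch.rand(2, 8, 256)                          # hypothetical latent tokens
x = torch.rand(2, 16, 256)                         # hypothetical data tokens
z, x = block(z, x)
print(z.shape, x.shape)                            # torch.Size([2, 8, 256]) torch.Size([2, 16, 256])
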
Anymate/utils/eval_utils.py ADDED
@@ -0,0 +1,225 @@
1
+ from tqdm import tqdm
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ import point_cloud_utils as pcu
6
+ from Anymate.utils.loss_utils import chamfer_distance_with_average, cross_entropy_with_probs_batch, cos_loss, cos_loss_clamp
7
+ from ThirdParty.Rignet_utils.utils import get_skel
8
+ from ThirdParty.Rignet_utils.Rignet_loss import edit_dist, chamfer_dist, joint2bone_chamfer_dist, bone2bone_chamfer_dist
9
+ from scipy.optimize import linear_sum_assignment
10
+
11
+ def evaluate_joint(joints, joints_gt, threshold=1e-1):
12
+ """
13
+ joints: list of predicted joints: tensor of shape (n,joints_num,3)
14
+ joints_gt: list of ground truth joints : tensor of shape (n,joints_num,3)
15
+ """
16
+ chamfer_loss_all = 0
17
+ emd_loss_all = 0
18
+ precision = 0
19
+ recall = 0
20
+ count = 0
21
+
22
+ for i in tqdm(range(len(joints))):
23
+ joint_predict = joints[i].cpu()
24
+ joint_gt = joints_gt[i].cpu()
25
+ distance_matrix = torch.cdist(joint_gt, joint_predict) # (n_gt, n_predict)
26
+ n_gt,n_predict = distance_matrix.shape
27
+ min_distance_pred = torch.min(distance_matrix, dim=0)
28
+ min_distance_gt = torch.min(distance_matrix, dim=1)
29
+ precision += torch.sum(min_distance_pred.values < threshold).item()/n_predict
30
+ recall += torch.sum(min_distance_gt.values < threshold).item()/n_gt
31
+
32
+ chamfer_loss_all += chamfer_distance_with_average(joint_predict.unsqueeze(0), joint_gt.unsqueeze(0))
33
+ joint_predict = joint_predict.numpy().astype(np.float64)
34
+ joint_gt = joint_gt.numpy().astype(np.float64)
35
+ emd,_ = pcu.earth_movers_distance(joint_predict, joint_gt)
36
+ emd_loss_all += emd
37
+
38
+ count += 1
39
+
40
+ print('------------------------------------')
41
+ print('Evaluation results for joint:')
42
+ print('chamfer_loss:', chamfer_loss_all/count)
43
+ print('emd_loss:', emd_loss_all/count)
44
+ print('precision:', precision/count)
45
+ print('recall:', recall/count)
46
+ print('count:', count)
47
+ print('------------------------------------')
48
+ return chamfer_loss_all/count, emd_loss_all/count, precision/count, recall/count
49
+
50
+ def evaluate_connectivity(conns, conns_gt, joints_gt, vox_list):
51
+
52
+ """
53
+ conns: list of predicted connections probability: tensor of shape (n,joints_num,joints_num)
54
+ conns_gt: list of ground truth connections: tensor of shape (n,joints_num,joints_num)
55
+ """
56
+
57
+ precision_all = 0
58
+ recall_all = 0
59
+ cross_entropy_all = 0
60
+ bone2bone_dist_con = 0
61
+ count = 0
62
+ for i in tqdm(range(len(conns))):
63
+
64
+ conn_predict = conns[i].cpu().numpy()
65
+ conn_gt = conns_gt[i].cpu().numpy()
66
+ joints = joints_gt[i].cpu().numpy()
67
+ vox = vox_list[i]
68
+
69
+ cross_entropy_all += cross_entropy_with_probs_batch(torch.from_numpy(conn_predict).unsqueeze(0), torch.from_numpy(conn_gt).unsqueeze(0), reduction='mean')
70
+ # consider to add tree edit distance (need joint and vox information)
71
+ pred_skel, parent_matrix = get_skel(joints, conn_predict, vox=vox)
72
+ gt_skel, parent_matrix = get_skel(joints, conn_gt, vox=vox)
73
+ bone2bone_dist_con += bone2bone_chamfer_dist(pred_skel, gt_skel)
74
+
75
+ conn_predict = np.argmax(conn_predict, axis=1)
76
+ conn_gt = np.argmax(conn_gt, axis=1)
77
+ connection_matrix_pre = torch.zeros((len(conn_predict),len(conn_predict)))
78
+ connection_matrix_gt = torch.zeros((len(conn_predict),len(conn_predict)))
79
+
80
+ for i in range(len(conn_predict)):
81
+ connection_matrix_pre[i][conn_predict[i]] = 1
82
+ connection_matrix_pre[conn_predict[i]][i] = 1
83
+ connection_matrix_gt[i][conn_gt[i]] = 1
84
+ connection_matrix_gt[conn_gt[i]][i] = 1
85
+
86
+ TP = 0
87
+ FP = 0
88
+ FN = 0
89
+ FP = 0
90
+
91
+ for i in range(len(conn_predict)):
92
+ if connection_matrix_gt[i][conn_predict[i]] == 1:
93
+ TP += 1
94
+ if connection_matrix_gt[i][conn_predict[i]] == 0:
95
+ FP += 1
96
+ if connection_matrix_pre[i][conn_gt[i]] == 0:
97
+ FN += 1
98
+
99
+ precision = TP/(TP+FP)
100
+ recall = TP/(TP+FN)
101
+
102
+ precision_all += precision
103
+ recall_all += recall
104
+ count+=1
105
+ print('------------------------------------')
106
+ print('Evaluation results for connectivity:')
107
+ print('precision:',precision_all/count)
108
+ print('recall:',recall_all/count)
109
+ print('cross_entropy:',cross_entropy_all/count)
110
+ print('bone2bone_dist_con:',bone2bone_dist_con/count)
111
+ print('count:',count)
112
+ print('------------------------------------')
113
+ return precision_all/count, recall_all/count
114
+
115
+ def evaluate_skinning(skins, skins_gt, threshold=5e-2):
116
+ """
117
+ skins: list of predicted skinning weights: tensor of shape (n,vertices_num, bones_num)
118
+ skins_gt: list of ground truth skinning weights: tensor of shape (n,vertices_num, bones_num)
119
+ """
120
+ cs_loss = 0
121
+ ce_loss = 0
122
+ cs_loss_clamp = 0
123
+ count = 0
124
+ L1_loss = 0
125
+ precision = 0
126
+ recall = 0
127
+ mean_l1_dist = 0
128
+
129
+ for i in tqdm(range(len(skins))):
130
+ skin_predict = skins[i].cpu().unsqueeze(0)
131
+ skin_gt = skins_gt[i].cpu().unsqueeze(0)
132
+
133
+ precision_one = 0
134
+ recall_one = 0
135
+
136
+ ce_loss += cross_entropy_with_probs_batch(skin_predict, skin_gt)
137
+ cs_loss += cos_loss(skin_predict, skin_gt)
138
+ cs_loss_clamp += cos_loss_clamp(skin_predict, skin_gt)
139
+ L1_loss += F.l1_loss(skin_predict, skin_gt)
140
+ skin_predict = skin_predict[0].cpu().detach().numpy()
141
+ skin_gt = skin_gt[0].cpu().detach().numpy()
142
+ mean_l1_dist += np.sum(np.abs(skin_predict - skin_gt )) / len(skin_predict)
143
+
144
+ for i in range(len(skin_predict)):
145
+ influencial_bone_predict = skin_predict[i] >=threshold
146
+ influencial_bone_gt = skin_gt[i] >=threshold
147
+ influencial_bone_correct = influencial_bone_predict*influencial_bone_gt
148
+
149
+ if np.sum(influencial_bone_predict)==0 or np.sum(influencial_bone_gt)==0:
150
+ continue
151
+ precision_one += np.sum(influencial_bone_correct)/np.sum(influencial_bone_predict)
152
+ recall_one += np.sum(influencial_bone_correct)/np.sum(influencial_bone_gt)
153
+
154
+ precision += precision_one/len(skin_predict)
155
+ recall += recall_one/len(skin_predict)
156
+ count +=1
157
+
158
+ print('------------------------------------')
159
+ print('Evaluation results for skinning:')
160
+ print('cos loss: ', cs_loss/count)
161
+ print('ce loss: ', ce_loss/count)
162
+ print('cs_loss_clamp: ', cs_loss_clamp/count)
163
+ print('L1 loss: ', L1_loss/count)
164
+ print('mean_l1_dist: ', mean_l1_dist/count)
165
+ print('precision: ', precision/count)
166
+ print('recall: ', recall/count)
167
+ print('count: ', count)
168
+ print('------------------------------------')
169
+
170
+ def evaluate_skeleton(joints,joints_gt,conns,conns_gt,vox_list,fs_threshold=0.2):
171
+
172
+ """
173
+ joints: list of predicted joints: tensor of shape (n,joints_num,3)
174
+ joints_gt: list of ground truth joints : tensor of shape (n,joints_num,3)
175
+ conns: list of predicted connections probability: tensor of shape (n,joints_num,joints_num)
176
+ conns_gt: list of ground truth connections: tensor of shape (n,joints_num,joints_num)
177
+ vox_list: list of voxel: (n,88,88,88)
178
+ """
179
+
180
+ data_count = 0
181
+ chamfer_score = 0
182
+ j2b_chamfer_joint = 0
183
+ bone2bone_dist_joint = 0
184
+ edit_distance_joint = 0
185
+ joint_IoU_total = 0
186
+ joint_precision_total = 0
187
+ joint_recall_total = 0
188
+
189
+ for i in tqdm(range(len(joints))):
190
+ joint_predict = joints[i].cpu().numpy()
191
+ joint_gt = joints_gt[i].cpu().numpy()
192
+ conn_predict = conns[i].cpu().numpy()
193
+ conn_gt = conns_gt[i].cpu().numpy()
194
+ vox = vox_list[i]
195
+
196
+ # add shape diameter after we have vertex and faces
197
+ # shape_diameter = get_shape_diameter(mesh, points, parent_index[:,0])
198
+
199
+ dist_matrix = np.sqrt(np.sum((joint_predict[np.newaxis, ...] - joint_gt[:, np.newaxis, :]) ** 2, axis=2))
200
+ row_ind, col_ind = linear_sum_assignment(dist_matrix)
201
+ # fs_threshold = shape_diameter[row_ind]
202
+ joint_IoU = 2 * np.sum(dist_matrix[row_ind, col_ind] < fs_threshold) / (len(joint_predict) + len(joint_gt))
203
+ joint_IoU_total += joint_IoU
204
+ joint_precision = np.sum(dist_matrix[row_ind, col_ind] < fs_threshold) / len(joint_predict)
205
+ joint_precision_total += joint_precision
206
+ joint_recall = np.sum(dist_matrix[row_ind, col_ind] < fs_threshold) / len(joint_gt)
207
+ joint_recall_total += joint_recall
208
+
209
+ pred_skel_joint,parent_matrix = get_skel(joint_predict,conn_predict,vox=vox)
210
+ gt_skel, parent_matrix = get_skel(joint_gt,conn_gt,vox=vox)
211
+ chamfer_score += chamfer_dist(joint_predict, joint_gt)
212
+ j2b_chamfer_joint += joint2bone_chamfer_dist(pred_skel_joint, gt_skel)
213
+ bone2bone_dist_joint += bone2bone_chamfer_dist(pred_skel_joint, gt_skel)
214
+ edit_distance_joint += edit_dist(pred_skel_joint, gt_skel)
215
+ data_count+=1
216
+
217
+ print('------------------------------------')
218
+ print('Evaluation results for skeleton:')
219
+ print('chamfer_score:', chamfer_score/data_count)
220
+ print('j2b_chamfer_joint:', j2b_chamfer_joint/data_count)
221
+ print('bone2bone_dist_joint:', bone2bone_dist_joint/data_count)
222
+ print('joint_IoU:', joint_IoU_total/data_count)
223
+ print('joint_precision:', joint_precision_total/data_count)
224
+ print('joint_recall:', joint_recall_total/data_count)
225
+ print('------------------------------------')
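
A minimal call sketch for `evaluate_joint` above, assuming this module and its dependencies (point_cloud_utils, the ThirdParty RigNet utilities) are importable; the joint counts and random tensors are purely illustrative:

import torch

pred_joints = [torch.rand(24, 3) for _ in range(4)]   # 4 hypothetical predictions
gt_joints = [torch.rand(30, 3) for _ in range(4)]     # ground truth; counts may differ per mesh
chamfer, emd, precision, recall = evaluate_joint(pred_joints, gt_joints, threshold=0.1)
print(chamfer, emd, precision, recall)
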
Anymate/utils/loss_utils.py ADDED
@@ -0,0 +1,56 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ import torch.nn as nn
4
+ def chamfer_distance_with_average(p1, p2):
5
+
6
+ '''
7
+ Calculate Chamfer Distance between two point sets
8
+ :param p1: size[1, N, D]
9
+ :param p2: size[1, M, D]
10
+ :param debug: whether need to output debug info
11
+ :return: sum of Chamfer Distance of two point sets
12
+ '''
13
+
14
+ assert p1.size(0) == 1 and p2.size(0) == 1
15
+ assert p1.size(2) == p2.size(2)
16
+ p1 = p1.repeat(p2.size(1), 1, 1)
17
+ p1 = p1.transpose(0, 1)
18
+ p2 = p2.repeat(p1.size(0), 1, 1)
19
+ dist = torch.add(p1, torch.neg(p2))
20
+ dist_norm = torch.norm(dist, 2, dim=2)
21
+ dist1 = torch.min(dist_norm, dim=1)[0]
22
+ dist2 = torch.min(dist_norm, dim=0)[0]
23
+ loss = 0.5 * ((torch.mean(dist1)) + (torch.mean(dist2)))
24
+ return loss
25
+
26
+ def cross_entropy_with_probs_batch(input, target, weight=None, reduction="mean"): # tested, same as nn.CrossEntropyLoss at dim=1, CE can be negative
27
+ # input_logsoftmax = F.log_softmax(input, dim=2)
28
+ input_logsoftmax = torch.log(input+1e-6)
29
+ cum_losses = -target * input_logsoftmax
30
+ if weight is not None:
31
+ cum_losses = cum_losses * weight.unsqueeze(1) # Broadcasting the weight
32
+
33
+ if reduction == "none":
34
+ return cum_losses
35
+ elif reduction == "mean":
36
+ return cum_losses.sum(dim=2).mean(dim=1).mean(dim=0)
37
+ elif reduction == "sum":
38
+ return cum_losses.sum(dim=2).sum(dim=1).mean(dim=0)
39
+ else:
40
+ raise ValueError("Keyword 'reduction' must be one of ['none', 'mean', 'sum']")
41
+
42
+ def cos_loss(input, target):
43
+ # input = F.softmax(input, dim=-1)
44
+ cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
45
+ similarity = cos(input, target)
46
+ loss = 1 - similarity.mean()
47
+ return loss
48
+
49
+ def cos_loss_clamp(input, target):
50
+ # input = F.softmax(input, dim=-1)*(1 + 2*0.001) - 0.001
51
+ input = input*(1 + 2*0.001) - 0.001
52
+ input = torch.clamp(input, 0, 1)
53
+ cos = nn.CosineSimilarity(dim=-1, eps=1e-6)
54
+ similarity = cos(input, target)
55
+ loss = 1 - similarity.mean()
56
+ return loss
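
A tiny sanity-check sketch for the losses above, assuming only torch; the batch of 100 points and 32 bones is hypothetical:

import torch

pred = torch.softmax(torch.rand(2, 100, 32), dim=-1)   # predicted skin weights
gt = torch.softmax(torch.rand(2, 100, 32), dim=-1)     # ground-truth skin weights
print(cross_entropy_with_probs_batch(pred, gt).item())
print(cos_loss(pred, gt).item())

p1 = torch.rand(1, 24, 3)   # predicted joints (batch size must be 1)
p2 = torch.rand(1, 30, 3)   # ground-truth joints
print(chamfer_distance_with_average(p1, p2).item())
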
Anymate/utils/render_utils.py ADDED
@@ -0,0 +1,1169 @@
1
+ import bpy
2
+ import numpy as np
3
+ from mathutils import Vector, Matrix
4
+ from tqdm import tqdm
5
+ import glob
6
+ import os
7
+ import torch
8
+ from PIL import Image
9
+ import matplotlib.pyplot as plt
10
+ cmap = plt.get_cmap('viridis')
11
+ import torch
12
+ import torchvision.io as io
13
+ import cv2
14
+ import trimesh
15
+
16
+ def get_data(ids, root, animate=False, shift_rig=True, id2=None, rignet=False):
17
+ dataset= torch.load('/data2/aod/testJointDataSet_9.pt')
18
+ joints = []
19
+ conns = []
20
+ skins = []
21
+
22
+ for id in ids:
23
+ if id2 is None:
24
+ for data in dataset:
25
+ if id in data['name']:
26
+ print(data['name'])
27
+ break
28
+ else:
29
+ for data in dataset:
30
+ if id2 in data['name']:
31
+ print(data['name'])
32
+ break
33
+
34
+ joint = torch.tensor(torch.load(root + '/joints/' + id + '.pt')).cpu()
35
+ if shift_rig and id2 is None:
36
+ y_max = data['points_cloud'][:,1].max()
37
+ joint = joint/2 + torch.tensor([0,y_max/2,0])
38
+ temp = joint[:, 1].clone()
39
+ joint[:, 1] = -joint[:, 2]
40
+ joint[:, 2] = temp
41
+
42
+ conn = torch.tensor(torch.load(root + '/connectivity/' + id + '.pt')).long()
43
+ if not animate:
44
+ skin = torch.load(root + '/skinning/' + id + '.pt')
45
+ if rignet:
46
+ skins.append(skin[0])
47
+ elif id2 is None:
48
+ skins.append(skin[0].softmax(dim=-1).cpu().numpy())
49
+ else:
50
+ skins.append(skin)
51
+
52
+ joints.append(joint)
53
+ conns.append(conn)
54
+
55
+ return joints, conns, skins
56
+
57
+ def index_to_sparse(index, weight, shape):
58
+ sparse_matrix = np.zeros([shape[0], shape[1], shape[2]+1])
59
+
60
+ row_indices, col_indices = np.meshgrid(np.arange(sparse_matrix.shape[0]), np.arange(sparse_matrix.shape[1]), indexing='ij')
61
+
62
+ row_indices = np.expand_dims(row_indices, axis=-1)
63
+ col_indices = np.expand_dims(col_indices, axis=-1)
64
+
65
+ sparse_matrix[row_indices, col_indices, index] = weight
66
+
67
+
68
+ return torch.from_numpy(sparse_matrix[:, :, :-1])
69
+
70
+ def get_gt(ids, root):
71
+ dataset= torch.load('/data2/aod/testJointDataSet_9.pt')
72
+ joints = []
73
+ conns = []
74
+ skins = []
75
+
76
+ for id in ids:
77
+ for data in dataset:
78
+ if id in data['name']:
79
+ print(data['name'])
80
+ break
81
+
82
+ joint = data['joints_matrix'][:data['joints_num'], :3]
83
+ y_max = data['points_cloud'][:,1].max()
84
+ joint = joint/2 + torch.tensor([0,y_max/2,0])
85
+ temp = joint[:, 1].clone()
86
+ joint[:, 1] = -joint[:, 2]
87
+ joint[:, 2] = temp
88
+
89
+ conn = data['parent_index'][:data['joints_num']].long().unsqueeze(1)
90
+
91
+ skin = index_to_sparse(data['skin_index'].unsqueeze(0), data['skin_weight'].unsqueeze(0), [1, 8192, data['joints_num']])
92
+
93
+ joints.append(joint)
94
+ conns.append(conn)
95
+ skins.append(skin[0])
96
+
97
+ return joints, conns, skins
98
+
99
+ def empty():
100
+ bpy.ops.wm.read_homefile(use_empty=True)
101
+ # Delete all mesh objects from the scene
102
+ # for obj in bpy.context.scene.objects:
103
+ # bpy.data.objects.remove(obj, do_unlink=True)
104
+
105
+ def add_mesh(filepath, co=None, tex=False, color=(0.5, 0.5, 0.5, 1)):
106
+ bpy.ops.wm.obj_import(filepath=filepath)
107
+ obj = bpy.context.object
108
+
109
+ if not tex:
110
+ # give the mesh a material
111
+ bpy.context.view_layer.objects.active = obj
112
+ bpy.ops.object.shade_smooth()
113
+ bpy.ops.object.mode_set(mode='EDIT')
114
+ bpy.ops.mesh.select_all(action='SELECT')
115
+ bpy.ops.mesh.normals_make_consistent(inside=False)
116
+ bpy.ops.object.mode_set(mode='OBJECT')
117
+ mat = bpy.data.materials.new(name='mat')
118
+ obj.data.materials.clear()
119
+ obj.data.materials.append(mat)
120
+ mat.use_nodes = True
121
+ mat.node_tree.nodes.clear()
122
+ bsdf = mat.node_tree.nodes.new('ShaderNodeBsdfPrincipled')
123
+ output = mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
124
+ mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
125
+ mat.node_tree.nodes['Principled BSDF'].inputs['Roughness'].default_value = 0.8
126
+ # mat.node_tree.nodes['Principled BSDF'].inputs['Specular'].default_value = 0.5
127
+ # mat.node_tree.nodes['Principled BSDF'].inputs['Metallic'].default_value = 0.5
128
+ mat.node_tree.nodes['Principled BSDF'].inputs['Base Color'].default_value = color
129
+ if co is not None:
130
+ obj.parent = co
131
+
132
+ def create_sphere(location, size=0.01, color=(1.0, 0.0, 0.0, 1.0), reduced=False):
133
+ if reduced:
134
+ bpy.ops.mesh.primitive_uv_sphere_add(radius=size, location=location, segments=8, ring_count=4)
135
+ else:
136
+ bpy.ops.mesh.primitive_uv_sphere_add(radius=size, location=location)
137
+ sphere = bpy.context.active_object
138
+
139
+ material_name = f"ColorMaterial_{color}"
140
+ material = bpy.data.materials.get(material_name)
141
+
142
+ if not material:
143
+ material = bpy.data.materials.new(name=material_name)
144
+ material.use_nodes = True
145
+ material.node_tree.nodes.clear()
146
+ bsdf = material.node_tree.nodes.new('ShaderNodeBsdfPrincipled')
147
+ output = material.node_tree.nodes.new('ShaderNodeOutputMaterial')
148
+ material.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
149
+ material.node_tree.nodes['Principled BSDF'].inputs['Base Color'].default_value = color
150
+
151
+ sphere.data.materials.append(material)
152
+
153
+ return sphere
154
+
155
+ def add_co(location=(0,0,0), rotation=(0,0,0), scale=(1,1,1)):
156
+ co = bpy.data.objects.new("CoordinateSystem", None)
157
+ bpy.context.collection.objects.link(co)
158
+ bpy.context.view_layer.objects.active = co
159
+ co.empty_display_size = 0.1
160
+ co.empty_display_type = 'ARROWS'
161
+ co.location = location
162
+ co.rotation_euler = rotation
163
+ co.scale = scale
164
+
165
+ return co
166
+
167
+ def add_joint(joints_matrix, co=None):
168
+
169
+ for i, joint in enumerate(joints_matrix):
170
+ sphere = create_sphere((joint[0], joint[1], joint[2]), size=0.01)
171
+ if co is not None:
172
+ sphere.parent = co
173
+
174
+ def create_blue_cone(base_point, apex_point, radius=0.1):
175
+ # Calculate the direction and length of the cone
176
+ direction = apex_point - base_point
177
+ length = direction.length
178
+
179
+ # Create cone mesh
180
+ bpy.ops.mesh.primitive_cone_add(vertices=32, radius1=radius, depth=length, location=(base_point + direction * 0.5))
181
+ cone = bpy.context.active_object
182
+
183
+ # Create or get the blue material
184
+ blue_material = bpy.data.materials.get("BlueMaterial")
185
+ if not blue_material:
186
+ blue_material = bpy.data.materials.new(name="BlueMaterial")
187
+ blue_material.use_nodes = True
188
+ blue_material.node_tree.nodes.clear()
189
+ bsdf = blue_material.node_tree.nodes.new('ShaderNodeBsdfPrincipled')
190
+ output = blue_material.node_tree.nodes.new('ShaderNodeOutputMaterial')
191
+ blue_material.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
192
+ blue_material.node_tree.nodes['Principled BSDF'].inputs['Base Color'].default_value = (0.0, 0.0, 1.0, 1.0)
193
+
194
+ cone.data.materials.append(blue_material)
195
+
196
+ # Set the cone's orientation
197
+ cone.rotation_euler = direction.to_track_quat('Z', 'Y').to_euler()
198
+
199
+ return cone
200
+
201
+ def add_conn(con_index, joints_matrix, co=None):
202
+ for i, parent in enumerate(con_index):
203
+ parent = parent.item()
204
+ if parent != i:
205
+ parent_co = Vector((joints_matrix[parent][0], joints_matrix[parent][1], joints_matrix[parent][2]))
206
+ position = Vector((joints_matrix[i][0], joints_matrix[i][1], joints_matrix[i][2]))
207
+ cone = create_blue_cone(parent_co, position, radius=0.008)
208
+ if co is not None:
209
+ cone.parent = co
210
+
211
+ def merge_images(img1, img2, output_path, alpha=1):
212
+ image_mesh = Image.open(img1)
213
+ image_rig = Image.open(img2)
214
+
215
+ if alpha == 1:
216
+ image_mesh.paste(image_rig, (0, 0), image_rig)
217
+ image_mesh.save(output_path)
218
+ return
219
+
220
+ data = image_rig.getdata()
221
+ data2 = image_mesh.getdata()
222
+ new_data = []
223
+ for item, item2 in zip(data, data2):
224
+ if item[3] == 0:
225
+ new_data.append(item2)
226
+ else:
227
+ new_data.append((int(item[0]*alpha + item2[0]*(1-alpha)), int(item[1]*alpha + item2[1]*(1-alpha)), int(item[2]*alpha + item2[2]*(1-alpha)), 255))
228
+ image_mesh.putdata(new_data)
229
+
230
+ # image_mesh.paste(image_rig, (0, 0), image_rig)
231
+
232
+ image_mesh.save(output_path)
233
+
234
+ def merge_videos(video1, video2, output_path):
235
+
236
+ # overlap two videos together, video1 is the background, video2 is the foreground
237
+ # os.system(f'ffmpeg -i {video1} -i {video2} -filter_complex "[0:v][1:v] overlay=0:0:enable=\'between(t,0,60)\'" -pix_fmt yuv420p -c:a copy {output_path}')
238
+
239
+ frames_path_1 = glob.glob(video1 + '*.png')
240
+ total_frames = len(frames_path_1)
241
+ combined_frames = []
242
+ for i in range(total_frames):
243
+ frame1 = Image.open(f'{video1}{i:04d}.png')
244
+ frame2 = Image.open(f'{video2}{i:04d}.png')
245
+ frame1.paste(frame2, (0, 0), frame2)
246
+ combined_frames.append(frame1)
247
+
248
+ # paste the combined frames on a pure white background
249
+ combined_frames_white = []
250
+ for frame in combined_frames:
251
+ white = Image.new('RGB', frame.size, (255, 255, 255))
252
+ white.paste(frame, (0, 0), frame)
253
+ combined_frames_white.append(white)
254
+
255
+ combined_frames=combined_frames_white
256
+
257
+ combined_videos = torch.stack([torch.tensor(np.array(frame)) for frame in combined_frames])[..., :3]
258
+
259
+ # write the video with high quality
260
+ # io.write_video(output_path, combined_videos, 24)
261
+ io.write_video(output_path, combined_videos, 24, video_codec='libx264', options={'crf': '18'})
262
+
263
+ # convert the frames to an mp4 video
264
+
265
+ # video = cv2.VideoWriter(output_path, cv2.VideoWriter_fourcc(*'H264'), 30, (frame1.size[0], frame1.size[1]))
266
+ # for frame in combined_frames:
267
+ # video.write(cv2.cvtColor(np.array(frame), cv2.COLOR_RGB2BGR))
268
+ # video.release()
269
+
270
+ # video_1, audio_1, fps_1 = io.read_video(video1, pts_unit="sec")
271
+ # video_2, audio_2, fps_2 = io.read_video(video2, pts_unit="sec")
272
+ # non_zero = video_2.sum(dim=-1) != 0
273
+ # non_zero = torch.stack([non_zero, non_zero, non_zero], dim=-1)
274
+ # video_1[non_zero] = video_2[non_zero]
275
+ # io.write_video(output_path, video_1, int(fps_1['video_fps']))
276
+
277
+ def add_skin(filepath, skin, bone_index, co=None, pc=None):
278
+ bpy.ops.wm.obj_import(filepath=filepath)
279
+ obj = bpy.context.object
280
+
281
+ bpy.context.view_layer.objects.active = obj
282
+ bpy.ops.object.shade_smooth()
283
+ bpy.ops.object.mode_set(mode='EDIT')
284
+ bpy.ops.mesh.select_all(action='SELECT')
285
+ bpy.ops.mesh.normals_make_consistent(inside=False)
286
+ bpy.ops.object.mode_set(mode='OBJECT')
287
+
288
+ if co is not None:
289
+ obj.parent = co
290
+
291
+ if pc is not None:
292
+ skin = np.array(skin)
293
+ pc = pc[:, :3].numpy()
294
+ y_max = pc[:, 1].max()
295
+ pc = pc + np.array([0, y_max, 0])
296
+ pc = pc / 2
297
+ new_skin = np.zeros((len(obj.data.vertices), skin.shape[1]))
298
+ for i, v in enumerate(obj.data.vertices):
299
+ v_co = np.array(v.co)
300
+
301
+ dist = np.linalg.norm(pc - v_co, axis=1)
302
+ # min_idx = np.argmin(dist)
303
+ # sort and take the indices of the 3 nearest points
304
+ min_idx_list = np.argsort(dist)[:3]
305
+
306
+ for min_idx in min_idx_list:
307
+ # get inverse distance weight
308
+ interpolate_weight = np.square(1 / dist[min_idx]) / np.square(1 / dist[min_idx_list]).sum()
309
+ new_skin[i] = new_skin[i] + interpolate_weight * skin[min_idx]
310
+
311
+ skin = new_skin
312
+
313
+ color_list = skin
314
+
315
+ color_list = color_list[:,bone_index]
316
+
317
+ vertex_colors = obj.data.vertex_colors.new()
318
+
319
+ for poly in obj.data.polygons:
320
+ for loop_index in poly.loop_indices:
321
+
322
+ vertex_index = obj.data.loops[loop_index].vertex_index
323
+ # Get the weight for the vertex
324
+ weight = color_list[vertex_index]
325
+
326
+ color = cmap(weight)
327
+
328
+ # Assign the weight to the vertex color (RGBA)
329
+ vertex_colors.data[loop_index].color = color # Use the weight for RGB
330
+
331
+ # let bsdf use vertex color and then output to surface
332
+ mat = bpy.data.materials.new(name='mat')
333
+ # delete all material of obj
334
+ obj.data.materials.clear()
335
+ obj.data.materials.append(mat)
336
+ mat.use_nodes = True
337
+ mat.node_tree.nodes.clear()
338
+ vertex_color = mat.node_tree.nodes.new('ShaderNodeVertexColor')
339
+ bsdf = mat.node_tree.nodes.new('ShaderNodeBsdfPrincipled')
340
+ output = mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
341
+ mat.node_tree.links.new(vertex_color.outputs['Color'], bsdf.inputs['Base Color'])
342
+ mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
343
+ mat.node_tree.nodes['Principled BSDF'].inputs['Roughness'].default_value = 0.5
344
+
345
+
346
+
347
+ def add_pc(points):
348
+ base_sphere = create_sphere((points[0][0], points[0][1], points[0][2]), size=0.003, color=cmap(0), reduced=True)
349
+ # copy the base sphere to create the rest of the spheres
350
+ for i in tqdm(range(1, points.shape[0])):
351
+ new_sphere = base_sphere.copy()
352
+ new_sphere.location = (points[i][0], points[i][1], points[i][2])
353
+ bpy.context.collection.objects.link(new_sphere)
354
+
355
+ def add_floor(back=False):
356
+ # create a plane as floor
357
+ bpy.ops.mesh.primitive_plane_add(size=50, enter_editmode=False, align='WORLD', location=(0, 20, 0))
358
+ floor = bpy.context.object
359
+ floor.name = 'floor'
360
+ # set white material for floor
361
+ mat = bpy.data.materials.new(name='floor_mat')
362
+ floor.data.materials.append(mat)
363
+ mat.use_nodes = True
364
+ mat.node_tree.nodes.clear()
365
+ bsdf = mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse')
366
+ output = mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
367
+ mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
368
+ mat.node_tree.nodes['Diffuse BSDF'].inputs['Color'].default_value = (1, 1, 1, 1)
369
+
370
+ if back:
371
+ # create a plane as background
372
+ bpy.ops.mesh.primitive_plane_add(size=30, enter_editmode=False, align='WORLD', location=(0, 15, 0), rotation=(-0.5*np.pi, 0, 0))
373
+ background = bpy.context.object
374
+ background.name = 'background'
375
+ # set white material for background
376
+ mat = bpy.data.materials.new(name='background_mat')
377
+ background.data.materials.append(mat)
378
+ mat.use_nodes = True
379
+ mat.node_tree.nodes.clear()
380
+ bsdf = mat.node_tree.nodes.new('ShaderNodeBsdfDiffuse')
381
+ output = mat.node_tree.nodes.new('ShaderNodeOutputMaterial')
382
+ mat.node_tree.links.new(bsdf.outputs['BSDF'], output.inputs['Surface'])
383
+ mat.node_tree.nodes['Diffuse BSDF'].inputs['Color'].default_value = (1, 1, 1, 1)
384
+
385
+ def setup_render():
386
+ # color management
387
+ bpy.context.scene.view_settings.view_transform = 'Standard'
388
+
389
+ # set the render engine to Cycles
390
+ bpy.context.scene.render.engine = 'CYCLES'
391
+ # enable cuda
392
+ bpy.context.preferences.addons['cycles'].preferences.get_devices()
393
+ bpy.context.preferences.addons['cycles'].preferences.compute_device_type = 'CUDA'
394
+ bpy.context.scene.cycles.device = 'GPU'
395
+
396
+ # set render background to transparent
397
+ bpy.context.scene.render.film_transparent = True
398
+
399
+ def render(output_path, shadow=True, shading=True, quick=False):
400
+
401
+ if shadow:
402
+ add_floor()
403
+
404
+ if shading:
405
+ # create a sun light
406
+ bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
407
+ light = bpy.context.object
408
+ light.data.energy = 5
409
+ # angle pointing to the origin
410
+ light.rotation_euler = (0.1*np.pi, 0, 0)
411
+ # set angle
412
+ light.data.angle = 0.08*np.pi
413
+
414
+ else:
415
+ # global illumination via a world light
416
+ world = bpy.data.worlds.new('World')
417
+ bpy.context.scene.world = world
418
+ world.use_nodes = True
419
+ world_light = world.node_tree.nodes['Background']
420
+ world_light.inputs['Strength'].default_value = 1
421
+ world_light.inputs['Color'].default_value = (1, 1, 1, 1)
422
+
423
+ # create a camera
424
+ cam = bpy.data.cameras.new("Camera")
425
+ cam_ob = bpy.data.objects.new("Camera", cam)
426
+ camera = bpy.data.objects['Camera']
427
+ bpy.context.scene.collection.objects.link(camera)
428
+ camera.location = Vector((2, -1.5, 2))
429
+ look_at = Vector((0, 0, 0.36))
430
+ # compute the rotation
431
+ camera.rotation_mode = 'QUATERNION'
432
+ camera.rotation_quaternion = (camera.location - look_at).to_track_quat('Z', 'Y')
433
+ # set size
434
+ camera.data.sensor_width = 26
435
+ # set the camera to be active
436
+ bpy.context.scene.camera = camera
437
+
438
+
439
+
440
+ # make the rendered image square
441
+ bpy.context.scene.render.resolution_x = 2048
442
+ bpy.context.scene.render.resolution_y = 2048
443
+
444
+ setup_render()
445
+
446
+ if quick:
447
+ # reduce the number of samples
448
+ bpy.context.scene.cycles.samples = 128
449
+ bpy.context.scene.cycles.preview_samples = 128
450
+ bpy.context.scene.cycles.max_bounces = 1
451
+ bpy.context.scene.cycles.min_bounces = 1
452
+ bpy.context.scene.cycles.diffuse_bounces = 1
453
+ bpy.context.scene.cycles.glossy_bounces = 1
454
+ else:
455
+ bpy.context.scene.cycles.samples = 1024
456
+ bpy.context.scene.cycles.preview_samples = 1024
457
+ bpy.context.scene.cycles.max_bounces = 4
458
+ bpy.context.scene.cycles.min_bounces = 4
459
+ bpy.context.scene.cycles.diffuse_bounces = 4
460
+ bpy.context.scene.cycles.glossy_bounces = 4
461
+
462
+ # output path
463
+ # output_path = '/home/ydengbd/objaverse/test.png'
464
+ bpy.context.scene.render.filepath = output_path
465
+ bpy.ops.render.render(write_still=True)
466
+
467
+ def render_spin(output_path, co, shadow=True, shading=True, quick=False):
468
+ # create a new coordinate system at the origin
469
+ new_co = add_co(location=(0, 0, 0), rotation=(0, 0, 0), scale=(1, 1, 1))
470
+ # set the object to be the child of the new coordinate system
471
+ co.parent = new_co
472
+
473
+ # add spin animation to the new coordinate system
474
+ new_co.rotation_mode = 'XYZ'
475
+ new_co.rotation_euler = (0, 0, 0)
476
+ new_co.keyframe_insert(data_path='rotation_euler', index=2, frame=0)
477
+ new_co.rotation_euler = (0, 0, 2*np.pi)
478
+ new_co.keyframe_insert(data_path='rotation_euler', index=2, frame=60)
479
+
480
+ if shadow:
481
+ add_floor()
482
+
483
+ if shading:
484
+ # create a sun light
485
+ bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
486
+ light = bpy.context.object
487
+ light.data.energy = 5
488
+ # angle pointing to the origin
489
+ light.rotation_euler = (0.1*np.pi, 0, 0)
490
+ # set angle
491
+ light.data.angle = 0.08*np.pi
492
+
493
+ else:
494
+ # global illumination via a world light
495
+ world = bpy.data.worlds.new('World')
496
+ bpy.context.scene.world = world
497
+ world.use_nodes = True
498
+ world_light = world.node_tree.nodes['Background']
499
+ world_light.inputs['Strength'].default_value = 1
500
+ world_light.inputs['Color'].default_value = (1, 1, 1, 1)
501
+
502
+ # create a camera
503
+ cam = bpy.data.cameras.new("Camera")
504
+ cam_ob = bpy.data.objects.new("Camera", cam)
505
+ camera = bpy.data.objects['Camera']
506
+ bpy.context.scene.collection.objects.link(camera)
507
+ camera.location = Vector((2, -1.5, 2))
508
+ look_at = Vector((0, 0, 0.36))
509
+ # compute the rotation
510
+ camera.rotation_mode = 'QUATERNION'
511
+ camera.rotation_quaternion = (camera.location - look_at).to_track_quat('Z', 'Y')
512
+ # set size
513
+ camera.data.sensor_width = 26
514
+ # set the camera to be active
515
+ bpy.context.scene.camera = camera
516
+
517
+
518
+ # render the animation
519
+ bpy.context.scene.frame_start = 0
520
+ bpy.context.scene.frame_end = 60
521
+
522
+ # make the rendered image square
523
+ bpy.context.scene.render.resolution_x = 1024
524
+ bpy.context.scene.render.resolution_y = 1024
525
+
526
+ setup_render()
527
+
528
+ if quick:
529
+ # reduce the number of samples
530
+ bpy.context.scene.cycles.samples = 128
531
+ bpy.context.scene.cycles.preview_samples = 128
532
+ bpy.context.scene.cycles.max_bounces = 1
533
+ bpy.context.scene.cycles.min_bounces = 1
534
+ bpy.context.scene.cycles.diffuse_bounces = 1
535
+ bpy.context.scene.cycles.glossy_bounces = 1
536
+ else:
537
+ bpy.context.scene.cycles.samples = 512
538
+ bpy.context.scene.cycles.preview_samples = 512
539
+ bpy.context.scene.cycles.max_bounces = 4
540
+ bpy.context.scene.cycles.min_bounces = 4
541
+ bpy.context.scene.cycles.diffuse_bounces = 4
542
+ bpy.context.scene.cycles.glossy_bounces = 4
543
+
544
+ # output path
545
+ bpy.context.scene.render.filepath = output_path
546
+ if output_path.endswith('.mp4'):
547
+ # render a mp4 video
548
+ bpy.context.scene.render.image_settings.file_format = 'FFMPEG'
549
+ bpy.context.scene.render.ffmpeg.format = 'MPEG4'
550
+ bpy.context.scene.render.ffmpeg.codec = 'H264'
551
+
552
+ bpy.ops.render.render(animation=True, write_still=True)
553
+
554
+ def setup_anim(armature, arti):
555
+ # enter pose mode
556
+ print('Arti shape', arti.shape)
557
+ bpy.ops.object.mode_set(mode='POSE')
558
+ print('total bones', len(armature.pose.bones))
559
+ for i, pose_bone in enumerate(armature.pose.bones):
560
+ pose_bone.rotation_mode = 'XYZ'
561
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=0)
562
+
563
+ pose_bone.rotation_euler = arti[i]
564
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=30)
565
+
566
+ pose_bone.rotation_euler = Vector((0, 0, 0))
567
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=60)
568
+ bpy.ops.object.mode_set(mode='OBJECT')
569
+
570
+ def render_anim(output_path, armature, arti, quick=False):
571
+ # enter pose mode
572
+ setup_anim(armature, arti)
573
+
574
+ # save blend file
575
+ # bpy.ops.wm.save_as_mainfile(filepath='/data2/ydengbd/objaverse/test.blend')
576
+
577
+ add_floor()
578
+
579
+ # create a sun light
580
+ bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
581
+ light = bpy.context.object
582
+ light.data.energy = 5
583
+ # angle pointing to the origin
584
+ light.rotation_euler = (50/180*np.pi, 0, -20/180*np.pi)
585
+ # set angle
586
+ light.data.angle = 12/180*np.pi
587
+
588
+ # create a camera
589
+ cam = bpy.data.cameras.new("Camera")
590
+ cam_ob = bpy.data.objects.new("Camera", cam)
591
+ camera = bpy.data.objects['Camera']
592
+ bpy.context.scene.collection.objects.link(camera)
593
+ camera.location = Vector((0, -3, 1.3))
594
+ camera.rotation_euler = Vector((1.309, 0, 0))
595
+ # set size
596
+ camera.data.sensor_width = 36
597
+ # set the camera to be active
598
+ bpy.context.scene.camera = camera
599
+
600
+ # render the animation
601
+ bpy.context.scene.frame_start = 0
602
+ bpy.context.scene.frame_end = 60
603
+
604
+ # set the render resolution to 1920x1080
605
+ bpy.context.scene.render.resolution_x = 1920
606
+ bpy.context.scene.render.resolution_y = 1080
607
+
608
+ setup_render()
609
+
610
+ if quick:
611
+ # reduce the number of samples
612
+ bpy.context.scene.cycles.samples = 128
613
+ bpy.context.scene.cycles.preview_samples = 128
614
+ bpy.context.scene.cycles.max_bounces = 1
615
+ bpy.context.scene.cycles.min_bounces = 1
616
+ bpy.context.scene.cycles.diffuse_bounces = 1
617
+ bpy.context.scene.cycles.glossy_bounces = 1
618
+ else:
619
+ bpy.context.scene.cycles.samples = 1024
620
+ bpy.context.scene.cycles.preview_samples = 1024
621
+ bpy.context.scene.cycles.max_bounces = 4
622
+ bpy.context.scene.cycles.min_bounces = 4
623
+ bpy.context.scene.cycles.diffuse_bounces = 4
624
+ bpy.context.scene.cycles.glossy_bounces = 4
625
+
626
+ # output path
627
+ bpy.context.scene.render.filepath = output_path
628
+ if output_path.endswith('.mp4'):
629
+ # render a mp4 video
630
+ bpy.context.scene.render.image_settings.file_format = 'FFMPEG'
631
+ bpy.context.scene.render.ffmpeg.format = 'MPEG4'
632
+ bpy.context.scene.render.ffmpeg.codec = 'H264'
633
+
634
+ bpy.ops.render.render(animation=True, write_still=True)
635
+
636
+
637
+ def render_animspin(output_path, co, armature, arti, shadow=True, shading=True, quick=False):
638
+ # enter pose mode
639
+ print('Arti shape', arti.shape)
640
+ bpy.ops.object.mode_set(mode='POSE')
641
+ print('total bones', len(armature.pose.bones))
642
+ for i, pose_bone in enumerate(armature.pose.bones):
643
+ pose_bone.rotation_mode = 'XYZ'
644
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=0)
645
+
646
+ pose_bone.rotation_euler = arti[i]
647
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=30)
648
+
649
+ pose_bone.rotation_euler = Vector((0, 0, 0))
650
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=60)
651
+
652
+ pose_bone.rotation_euler = arti[i]
653
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=90)
654
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=150)
655
+
656
+ pose_bone.rotation_euler = Vector((0, 0, 0))
657
+ pose_bone.keyframe_insert(data_path="rotation_euler", frame=180)
658
+ bpy.ops.object.mode_set(mode='OBJECT')
659
+
660
+ # create a new coordinate system at the origin
661
+ new_co = add_co(location=(0, 0, 0), rotation=(0, 0, 0), scale=(1, 1, 1))
662
+ # set the object to be the child of the new coordinate system
663
+ co.parent = new_co
664
+
665
+ # add spin animation to the new coordinate system
666
+ new_co.rotation_mode = 'XYZ'
667
+ new_co.rotation_euler = (0, 0, 0)
668
+ new_co.keyframe_insert(data_path='rotation_euler', index=2, frame=90)
669
+ new_co.rotation_euler = (0, 0, 2*np.pi)
670
+ new_co.keyframe_insert(data_path='rotation_euler', index=2, frame=150)
671
+
672
+ if shadow:
673
+ add_floor()
674
+
675
+ if shading:
676
+ # create a sun light
677
+ bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
678
+ light = bpy.context.object
679
+ light.data.energy = 5
680
+ # angle pointing to the origin
681
+ light.rotation_euler = (0.1*np.pi, 0, 0)
682
+ # set angle
683
+ light.data.angle = 0.08*np.pi
684
+
685
+ else:
686
+ # global illumination via a world light
687
+ world = bpy.data.worlds.new('World')
688
+ bpy.context.scene.world = world
689
+ world.use_nodes = True
690
+ world_light = world.node_tree.nodes['Background']
691
+ world_light.inputs['Strength'].default_value = 1
692
+ world_light.inputs['Color'].default_value = (1, 1, 1, 1)
693
+
694
+ # create a camera
695
+ cam = bpy.data.cameras.new("Camera")
696
+ cam_ob = bpy.data.objects.new("Camera", cam)
697
+ camera = bpy.data.objects['Camera']
698
+ bpy.context.scene.collection.objects.link(camera)
699
+ camera.location = Vector((2, -1.5, 2))
700
+ look_at = Vector((0, 0, 0.36))
701
+ # compute the rotation
702
+ camera.rotation_mode = 'QUATERNION'
703
+ camera.rotation_quaternion = (camera.location - look_at).to_track_quat('Z', 'Y')
704
+ # set size
705
+ camera.data.sensor_width = 26
706
+ # set the camera to be active
707
+ bpy.context.scene.camera = camera
708
+
709
+
710
+ # render the animation
711
+ bpy.context.scene.frame_start = 0
712
+ bpy.context.scene.frame_end = 180
713
+
714
+ # make the rendered image square
715
+ bpy.context.scene.render.resolution_x = 1024
716
+ bpy.context.scene.render.resolution_y = 1024
717
+
718
+ setup_render()
719
+
720
+ if quick:
721
+ # reduce the number of samples
722
+ bpy.context.scene.cycles.samples = 128
723
+ bpy.context.scene.cycles.preview_samples = 128
724
+ bpy.context.scene.cycles.max_bounces = 1
725
+ bpy.context.scene.cycles.min_bounces = 1
726
+ bpy.context.scene.cycles.diffuse_bounces = 1
727
+ bpy.context.scene.cycles.glossy_bounces = 1
728
+ else:
729
+ bpy.context.scene.cycles.samples = 512
730
+ bpy.context.scene.cycles.preview_samples = 512
731
+ bpy.context.scene.cycles.max_bounces = 4
732
+ bpy.context.scene.cycles.min_bounces = 4
733
+ bpy.context.scene.cycles.diffuse_bounces = 4
734
+ bpy.context.scene.cycles.glossy_bounces = 4
735
+
736
+ # output path
737
+ bpy.context.scene.render.filepath = output_path
738
+ if output_path.endswith('.mp4'):
739
+ # render a mp4 video
740
+ bpy.context.scene.render.image_settings.file_format = 'FFMPEG'
741
+ bpy.context.scene.render.ffmpeg.format = 'MPEG4'
742
+ bpy.context.scene.render.ffmpeg.codec = 'H264'
743
+
744
+ bpy.ops.render.render(animation=True, write_still=True)
745
+
746
+ def render_scene(output_path, shadow=True):
747
+
748
+ if shadow:
749
+ add_floor()
750
+
751
+
752
+ # create a sun light
753
+ bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
754
+ light = bpy.context.object
755
+ light.data.energy = 5
756
+ # angle pointing to the origin
757
+ light.rotation_euler = (50/180*np.pi, 0, -20/180*np.pi)
758
+ # set angle
759
+ light.data.angle = 12/180*np.pi
760
+
761
+ # create a camera
762
+ cam = bpy.data.cameras.new("Camera")
763
+ cam_ob = bpy.data.objects.new("Camera", cam)
764
+ camera = bpy.data.objects['Camera']
765
+ bpy.context.scene.collection.objects.link(camera)
766
+ camera.location = Vector((0, -10, 5))
767
+ camera.rotation_euler = Vector((1.22, 0, 0))
768
+ # set size
769
+ camera.data.sensor_width = 26
770
+ # set the camera to be active
771
+ bpy.context.scene.camera = camera
772
+
773
+
774
+
775
+ # set the render resolution to 1920x1080
776
+ bpy.context.scene.render.resolution_x = 1920
777
+ bpy.context.scene.render.resolution_y = 1080
778
+
779
+ setup_render()
780
+
781
+
782
+
783
+ # output path
784
+ # output_path = '/home/ydengbd/objaverse/test.png'
785
+ bpy.context.scene.render.filepath = output_path
786
+ bpy.ops.render.render(write_still=True)
787
+
788
+
789
+ def render_teaser(output_path, shadow=True, quick=False):
790
+
791
+ if shadow:
792
+ add_floor(back=True)
793
+
794
+ # create a sun light
795
+ bpy.ops.object.light_add(type='SUN', radius=1, align='WORLD', location=(-1, -1, 3))
796
+ light = bpy.context.object
797
+ light.data.energy = 5
798
+ # angle pointing to the origin
799
+ light.rotation_euler = (50/180*np.pi, 0, -20/180*np.pi)
800
+ # set angle
801
+ light.data.angle = 12/180*np.pi
802
+
803
+ # create a camera
804
+ cam = bpy.data.cameras.new("Camera")
805
+ cam_ob = bpy.data.objects.new("Camera", cam)
806
+ camera = bpy.data.objects['Camera']
807
+ bpy.context.scene.collection.objects.link(camera)
808
+ camera.location = Vector((0, -3, 1.3))
809
+ camera.rotation_euler = Vector((80/180*np.pi, 0, 0))
810
+ # set size
811
+ camera.data.sensor_width = 48
812
+ # set the camera to be active
813
+ bpy.context.scene.camera = camera
814
+
815
+ # render the animation
816
+ bpy.context.scene.frame_start = 0
817
+ bpy.context.scene.frame_end = 60
818
+
819
+ # set the render resolution to 2400x1080
820
+ bpy.context.scene.render.resolution_x = 2400
821
+ bpy.context.scene.render.resolution_y = 1080
822
+
823
+ setup_render()
824
+
825
+ if quick:
826
+ # reduce the number of samples
827
+ bpy.context.scene.cycles.samples = 128
828
+ bpy.context.scene.cycles.preview_samples = 128
829
+ bpy.context.scene.cycles.max_bounces = 1
830
+ bpy.context.scene.cycles.min_bounces = 1
831
+ bpy.context.scene.cycles.diffuse_bounces = 1
832
+ bpy.context.scene.cycles.glossy_bounces = 1
833
+ else:
834
+ bpy.context.scene.cycles.samples = 1024
835
+ bpy.context.scene.cycles.preview_samples = 1024
836
+ bpy.context.scene.cycles.max_bounces = 4
837
+ bpy.context.scene.cycles.min_bounces = 4
838
+ bpy.context.scene.cycles.diffuse_bounces = 4
839
+ bpy.context.scene.cycles.glossy_bounces = 4
840
+
841
+ # output path
842
+ bpy.context.scene.render.filepath = output_path
843
+ if output_path.endswith('.mp4'):
844
+ # render a mp4 video
845
+ bpy.context.scene.render.image_settings.file_format = 'FFMPEG'
846
+ bpy.context.scene.render.ffmpeg.format = 'MPEG4'
847
+ bpy.context.scene.render.ffmpeg.codec = 'H264'
848
+
849
+ bpy.ops.render.render(animation=True, write_still=True)
850
+
851
+ def setup_armature(path, tex=False, save=True):
852
+ joints_matrix = torch.load(os.path.join(path, 'joints.pt'))
853
+ connectivity = torch.load(os.path.join(path, 'conns.pt'))
854
+ skinning_weights = torch.load(os.path.join(path, 'skins.pt'))
855
+ obj_file_path = os.path.join(path, 'object.obj')
856
+
857
+ # bpy.ops.wm.obj_import(filepath=obj_file_path)
858
+ add_mesh(obj_file_path, tex=tex)
859
+ mesh_object = bpy.context.selected_objects[0]
860
+
861
+ # pack textures
862
+ bpy.ops.file.pack_all()
863
+
864
+ temp = torch.tensor(joints_matrix)[:, 1].clone()
865
+ joints_matrix[:, 1] = -joints_matrix[:, 2]
866
+ joints_matrix[:, 2] = temp
867
+
868
+ bpy.ops.object.armature_add()
869
+ armature_obj = bpy.context.object
870
+
871
+
872
+ bpy.ops.object.mode_set(mode='EDIT')
873
+ bpy.ops.armature.select_all(action='SELECT')
874
+ bpy.ops.armature.delete()
875
+
876
+ world_matrix = Matrix([[1, 0, 0, 0],
877
+ [0, 1, 0, 0],
878
+ [0, 0, 1, 0],
879
+ [0, 0, 0, 1]])
880
+ armature_obj.matrix_world = world_matrix
881
+
882
+ bone_dict = {}
883
+
884
+ i_name = 0
885
+
886
+ for i in range(len(joints_matrix)):
887
+
888
+ if connectivity[i] == i:
889
+ continue
890
+ bone_name = str(i_name)
891
+ bone = armature_obj.data.edit_bones.new(bone_name)
892
+ bone.head = joints_matrix[connectivity[i]].cpu().numpy()
893
+ bone.tail = joints_matrix[i].cpu().numpy()
894
+ bone_dict[bone_name] = bone
895
+ i_name += 1
896
+
897
+ for bone_name, bone in bone_dict.items():
898
+ # Find parent bone by checking if current bone's head matches any other bone's tail
899
+ for other_bone_name, other_bone in bone_dict.items():
900
+ if other_bone != bone and bone.head == other_bone.tail:
901
+ bone.parent = other_bone
902
+ break
903
+
904
+ assert i_name == skinning_weights.shape[1]
905
+
906
+ for i, skinning_weight in enumerate(skinning_weights):
907
+ # print("skinning_weight", skinning_weight)
908
+ vertex_index = i
909
+ for j,weight in enumerate(skinning_weight):
910
+ bone_name = str(j)
911
+ bone_weight = float(weight)
912
+
913
+ vertex_group_name = f"{bone_name}"
914
+ vertex_group = mesh_object.vertex_groups.get(vertex_group_name)
915
+ if vertex_group is None:
916
+ vertex_group = mesh_object.vertex_groups.new(name=vertex_group_name)
917
+ vertex_group.add([vertex_index], bone_weight, 'ADD')
918
+
919
+ # for obj in bpy.context.scene.objects:
920
+ # if obj.type == 'MESH':
921
+ modifier = mesh_object.modifiers.new(name="Armature", type='ARMATURE')
922
+ modifier.object = armature_obj
923
+ modifier.use_vertex_groups = True
924
+ print("Armature modifier added to mesh:", mesh_object.name)
925
+
926
+ bpy.ops.object.mode_set(mode='OBJECT')
927
+ if save:
928
+ bpy.ops.wm.save_as_mainfile(filepath= os.path.join(path, 'blender_output.blend'))
929
+
930
+ return armature_obj
931
+
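(For context, setup_armature() below expects three tensors saved next to object.obj. A minimal sketch of writing files in that layout follows; the shapes are assumptions inferred from how the tensors are indexed in this file, and save_rig is a hypothetical helper, not part of the repo.)

# Hypothetical helper illustrating the on-disk layout setup_armature(path) loads.
# Assumed shapes: joints (J, 3) positions, conns (J,) parent index per joint (root maps to itself),
# skins (V, B) per-vertex weights with one column per bone.
import os
import torch

def save_rig(path, joints, conns, skins):
    os.makedirs(path, exist_ok=True)
    torch.save(joints, os.path.join(path, 'joints.pt'))
    torch.save(conns, os.path.join(path, 'conns.pt'))
    torch.save(skins, os.path.join(path, 'skins.pt'))
    # object.obj must also be placed in `path` before calling setup_armature(path)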
932
+ def reload_tensor_skinning(data, bone_name_list):
933
+
934
+ # with open(json_file, "r") as f:
935
+ # skinning_data = json.load(f)
936
+
937
+ armature_obj = bpy.data.objects.get("Armature")
938
+ if not armature_obj:
939
+ print("Error: Armature object 'Armature' not found.")
940
+ return
941
+
942
+ # parent all mesh objects to the armature object
943
+ count = 0
944
+ for obj in bpy.context.scene.objects:
945
+ if obj.type == 'MESH':
946
+ obj.parent = armature_obj
947
+ count += 1
948
+
949
+ print("total mesh count:", count)
950
+
951
+ for obj in bpy.context.scene.objects:
952
+ vertex_index = 0
953
+ if obj.type == 'MESH':
954
+ # mesh_name = obj.name
955
+ # if mesh_name in skinning_data:
956
+ # skinning_info = skinning_data[mesh_name]
957
+ # if "weight" in skinning_info:
958
+ # print("Applying skinning data for mesh:", mesh_name)
959
+ # vertex_index = 0
960
+ # for vertex_weight in skinning_info["weight"]:
961
+ # for bone_name, weight_value in vertex_weight.items():
962
+ # vertex_group = obj.vertex_groups.get(bone_name)
963
+ # if vertex_group is None:
964
+ # vertex_group = obj.vertex_groups.new(name=bone_name)
965
+ # print("Vertex group created:", bone_name)
966
+ # vertex_group.add([vertex_index], weight_value, 'REPLACE')
967
+ # vertex_index += 1
968
+ # else:
969
+ # print("No skinning data found for mesh:", mesh_name)
970
+
971
+ for i, v in enumerate(obj.data.vertices):
972
+ v_co = np.array(v.co)
973
+ pc = data['pc'][:, :3].numpy()
974
+ y_max = pc[:, 1].max()
975
+ pc = pc + np.array([0, y_max, 0])
976
+ pc = pc / 2
977
+ dist = np.linalg.norm(pc - v_co, axis=1)
978
+ # min_idx = np.argmin(dist)
979
+ # sort, and then get top 3 index
980
+ min_idx_list = np.argsort(dist)[:3]
981
+
982
+ for min_idx in min_idx_list:
983
+ # get inverse distance weight
984
+ interpolate_weight = np.square(1 / dist[min_idx]) / np.square(1 / dist[min_idx_list]).sum()
985
+
986
+ for idx, j in enumerate(data['skins_index'][min_idx]):
987
+ if j == -1:
988
+ break
989
+ bone_name = bone_name_list[j]
990
+ vertex_group = obj.vertex_groups.get(str(int(bone_name)))
991
+ if vertex_group is None:
992
+ vertex_group = obj.vertex_groups.new(name=str(int(bone_name)))
993
+ print("Vertex group created:", bone_name)
994
+
995
+ vertex_group.add([i], interpolate_weight * data['skins_weight'][min_idx][idx], 'ADD')
996
+
997
+
998
+ for obj in bpy.context.scene.objects:
999
+ if obj.type == 'MESH':
1000
+ modifier = obj.modifiers.new(name="Armature", type='ARMATURE')
1001
+ modifier.object = armature_obj
1002
+ modifier.use_vertex_groups = True
1003
+ print("Armature modifier added to mesh:", obj.name)
1004
+
1005
+ def reload_tensor(data, root='data', save=True):
1006
+ joints_matrix = data['joints'].clone()
1007
+ connectivity = data['conns']
1008
+ obj_file_path = os.path.join(root, data['name'], 'object.obj')
1009
+
1010
+ # bpy.ops.wm.obj_import(filepath=obj_file_path)
1011
+ add_mesh(obj_file_path)
1012
+ mesh_object = bpy.context.selected_objects[0]
1013
+
1014
+ # pack textures
1015
+ bpy.ops.file.pack_all()
1016
+
1017
+ y_max = data['pc'][:, 1].max()
1018
+ joints_matrix = joints_matrix + torch.tensor([0, y_max, 0])
1019
+ joints_matrix = joints_matrix / 2
1020
+
1021
+ temp = joints_matrix[:, 1].clone()
1022
+ joints_matrix[:, 1] = -joints_matrix[:, 2]
1023
+ joints_matrix[:, 2] = temp
1024
+
1025
+ bpy.ops.object.armature_add()
1026
+ armature_obj = bpy.context.object
1027
+
1028
+
1029
+ bpy.ops.object.mode_set(mode='EDIT')
1030
+ bpy.ops.armature.select_all(action='SELECT')
1031
+ bpy.ops.armature.delete()
1032
+
1033
+ world_matrix = Matrix([[1, 0, 0, 0],
1034
+ [0, 1, 0, 0],
1035
+ [0, 0, 1, 0],
1036
+ [0, 0, 0, 1]])
1037
+ armature_obj.matrix_world = world_matrix
1038
+
1039
+ bone_dict = {}
1040
+ bone_name_list = np.zeros(data['bones_num'])
1041
+ i_name = 0
1042
+
1043
+ for i in range(len(joints_matrix)):
1044
+
1045
+ if connectivity[i] == i:
1046
+ continue
1047
+ bone_name = str(i_name)
1048
+ bone = armature_obj.data.edit_bones.new(bone_name)
1049
+ bone.head = joints_matrix[connectivity[i]].cpu().numpy()
1050
+ bone.tail = joints_matrix[i].cpu().numpy()
1051
+ bone_dict[bone_name] = bone
1052
+ for j, skinbone in enumerate(data['bones']):
1053
+ if torch.equal(skinbone[:3], data['joints'][connectivity[i]]) and torch.equal(skinbone[3:], data['joints'][i]):
1054
+ bone_name_list[j] = i_name
1055
+ i_name += 1
1056
+
1057
+ for bone_name, bone in bone_dict.items():
1058
+ # Find parent bone by checking if current bone's head matches any other bone's tail
1059
+ for other_bone_name, other_bone in bone_dict.items():
1060
+ if other_bone != bone and bone.head == other_bone.tail:
1061
+ bone.parent = other_bone
1062
+ break
1063
+
1064
+ print(bone_name_list)
1065
+
1066
+ reload_tensor_skinning(data, bone_name_list)
1067
+
1068
+ print("Armature modifier added to mesh:", mesh_object.name)
1069
+
1070
+ bpy.ops.object.mode_set(mode='OBJECT')
1071
+ if save:
1072
+ bpy.ops.wm.save_as_mainfile(filepath=os.path.join(root, data['name'], 'blender_output.blend'))
1073
+
1074
+ return armature_obj
1075
+
1076
+ def load_blender(blender_path):
1077
+
1078
+ bpy.ops.wm.read_homefile(use_empty=True)
1079
+ # bpy.ops.wm.append(directory=object_path, link=False)
1080
+ # load_object(object_path)
1081
+ bpy.ops.wm.open_mainfile(filepath=blender_path)
1082
+ armature_obj = []
1083
+ mesh_obj = []
1084
+ for obj in bpy.context.scene.objects:
1085
+ if obj.type == "ARMATURE":
1086
+ armature_obj.append(obj)
1087
+ if obj.type == "MESH":
1088
+ mesh_obj.append(obj)
1089
+
1090
+ print('mesh obj:', len(mesh_obj))
1091
+
1092
+
1093
+
1094
+ # start retrieving the mesh, skinning and rigging information
1095
+
1096
+ #1. retrieve the rigging information and save the world matrix of the armature object
1097
+ total_armature_info = {}
1098
+ joints_matrix = []
1099
+ bone_dict = {}
1100
+ parent_name= []
1101
+ bone_count = 0
1102
+ for obj in armature_obj:
1103
+ # depsgraph = bpy.context.evaluated_depsgraph_get()
1104
+ # obj = obj.evaluated_get(depsgraph)
1105
+ armature_info = {}
1106
+ armature_info["world_matrix"] = [list(row) for row in obj.matrix_world.copy()]
1107
+ translation = obj.matrix_world.translation
1108
+ for bone in obj.pose.bones:
1109
+
1110
+ joints_matrix.append(np.array(list((obj.matrix_world.to_3x3() @ bone.head+translation).copy())))
1111
+
1112
+ if bone.parent:
1113
+ parent_name.append(bone.parent.name)
1114
+ else:
1115
+ parent_name.append('root')
1116
+ bone_dict[bone.name] = bone_count
1117
+ bone_count += 1
1118
+ connectivity = torch.zeros(bone_count, dtype=torch.int32)
1119
+
1120
+ for i, bone_name in enumerate(parent_name):
1121
+ if bone_name == 'root':
1122
+ connectivity[i] = i
1123
+ else:
1124
+ connectivity[i] = bone_dict[bone_name]
1125
+ joints_matrix = torch.from_numpy(np.array(joints_matrix))
1126
+
1127
+ skinning_weight = torch.zeros(len(mesh_obj[0].data.vertices), joints_matrix.shape[0])
1128
+
1129
+ vertex_index = 0
1130
+ for obj in mesh_obj:
1131
+ vertex_groups = obj.vertex_groups
1132
+
1133
+
1134
+ for vertex in obj.data.vertices:
1135
+ vertex_info = {}
1136
+ for group in vertex.groups:
1137
+ name = vertex_groups[group.group].name
1138
+
1139
+ weight = group.weight
1140
+ skinning_weight[vertex.index][bone_dict[name]] = weight
1141
+
1142
+ obj_save_path = blender_path.replace('.blend', '.obj')
1143
+ bpy.ops.wm.obj_export(filepath=obj_save_path, export_materials=False)
1144
+ return joints_matrix,connectivity, skinning_weight
1145
+
1146
+
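(A hedged usage sketch for load_blender() above; the .blend path is a placeholder. The returned shapes follow the construction above: joints (J, 3), connectivity (J,), skinning weights (V, J).)

# Hypothetical call; the path is a placeholder.
joints, conns, skins = load_blender('/path/to/asset.blend')
print(joints.shape, conns.shape, skins.shape)
# Side effect: the scene geometry is also exported to /path/to/asset.obj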
1147
+ def save_scene(scene_path):
1148
+ # export the scene as a glb file
1149
+ if scene_path.endswith('.glb'):
1150
+ bpy.ops.export_scene.gltf(filepath=scene_path)
1151
+ bpy.ops.wm.save_as_mainfile(filepath=scene_path.replace('.glb', '.blend'))
1152
+ elif scene_path.endswith('.blend'):
1153
+ bpy.ops.wm.save_as_mainfile(filepath=scene_path)
1154
+ elif scene_path.endswith('.obj'):
1155
+ bpy.ops.wm.obj_export(filepath=scene_path, export_materials=False)
1156
+ else:
1157
+ raise ValueError(f"Unsupported file extension: {scene_path}")
1158
+
1159
+ if __name__ == '__main__':
1160
+ # load the mesh
1161
+ empty()
1162
+ add_mesh('/home/ydengbd/objaverse/obj/0001.obj')
1163
+ # load the joints
1164
+ joints_matrix = np.load('/home/ydengbd/objaverse/joints/0001.npy')
1165
+ add_joint(joints_matrix)
1166
+ # load the connections
1167
+ con_index = np.load('/home/ydengbd/objaverse/connections/0001.npy')
1168
+ add_conn(con_index)
1169
+ # load the skin
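(The __main__ demo above is cut off at the skin-loading step in this diff. As a rough illustration only, the remaining steps are typically a render call plus an optional scene export using the helpers defined above; the output paths below are placeholders.)

# Hypothetical continuation of the truncated demo; paths are placeholders.
render('/tmp/0001.png', shadow=True, shading=True, quick=True)   # still image via Cycles
save_scene('/tmp/0001.blend')                                    # keep a .blend copy for inspection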
Anymate/utils/train_utils.py ADDED
@@ -0,0 +1,406 @@
1
+ import os
2
+ import numpy as np
3
+ from tqdm import tqdm
4
+
5
+ import torch
6
+ import torch.backends.cudnn as cudnn
7
+ from torch.utils.tensorboard import SummaryWriter
8
+ from torch.utils.data import DataLoader
9
+
10
+ from Anymate.dataset import AnymateDataset, my_collate
11
+ from Anymate.model import EncoderDecoder
12
+ from Anymate.utils.loss_utils import cross_entropy_with_probs_batch, cos_loss, cos_loss_clamp, chamfer_distance_with_average
13
+ from Anymate.utils.vol_utils import get_co, get_gt, extract_keypoints
14
+ from torch.nn.parallel import DistributedDataParallel as DDP
15
+ from torch.distributed import init_process_group, destroy_process_group
16
+ from torch.utils.data.distributed import DistributedSampler
17
+
18
+ import point_cloud_utils as pcu
19
+ from sklearn.cluster import DBSCAN
20
+ from diffusers import DDPMScheduler, DDIMScheduler
21
+ import torch.nn.functional as F
22
+ from Anymate.utils.diffusion_utils import my_collate_diff, randn_tensor
23
+
24
+
25
+ def ddp_setup(rank: int, world_size: int, port: int):
26
+ """
27
+ Args:
28
+ rank: Unique identifier of each process
29
+ world_size: Total number of processes
30
+ """
31
+ os.environ["MASTER_ADDR"] = "localhost"
32
+ os.environ["MASTER_PORT"] = str(port)
33
+ torch.cuda.set_device(rank)
34
+ init_process_group(backend="nccl", rank=rank, world_size=world_size)
35
+
36
+ class AverageMeter(object):
37
+ """Computes and stores the average and current value"""
38
+ def __init__(self):
39
+ self.reset()
40
+
41
+ def reset(self):
42
+ self.val = 0.0
43
+ self.avg = 0.0
44
+ self.sum = 0.0
45
+ self.count = 0.0
46
+
47
+ def update(self, val, n=1):
48
+ self.val = val
49
+ self.sum += val * n
50
+ self.count += n
51
+ self.avg = self.sum / self.count
52
+
53
+ def accumulate(self, val, n=1):
54
+ self.val = val
55
+ self.sum += val
56
+ self.count += n
57
+ self.avg = self.sum / self.count
58
+
59
+ def save_checkpoint(state, is_best, checkpoint='checkpoint', filename='model_best.pth.tar', snapshot=None):
60
+ filepath = os.path.join(checkpoint, filename)
61
+ if is_best:
62
+ torch.save(state, filepath)
63
+
64
+ if snapshot and state['epoch'] % snapshot == 0:
65
+ torch.save(state, os.path.join(checkpoint, 'checkpoint_{}.pth.tar'.format(state['epoch'])))
66
+
67
+ def train_model(rank, world_size, config, args, shared_dict, port=12355):
68
+ ddp_setup(rank, world_size, port)
69
+ lowest_loss = 1e20
70
+ model_config = config['model']
71
+ model = EncoderDecoder(device=f'cuda:{rank}', dtype=torch.float32, **model_config)
72
+ model.to(f'cuda:{rank}')
73
+
74
+ if rank == 0:
75
+ print('only_embed', model.only_embed)
76
+ print('return_latents', model.return_latents)
77
+ print(model)
78
+ if not args.finetune:
79
+ model.encoder.requires_grad_(False)
80
+ model = DDP(model, device_ids=[rank])
81
+ optimizer_config = config['optimizer']
82
+ if args.finetune:
83
+ optimizer = torch.optim.Adam(model.module.parameters(), **optimizer_config)
84
+ else:
85
+ if args.encoder == 'miche':
86
+ optimizer = torch.optim.Adam(model.module.decoder.parameters(), **optimizer_config)
87
+ elif args.encoder == 'bert':
88
+ optimizer = torch.optim.Adam(list(model.module.decoder.parameters()) + list(model.module.point_proj.parameters()), **optimizer_config)
89
+ # optionally resume from a checkpoint
90
+ if args.resume:
91
+ try:
92
+ print("=> loading checkpoint '{}'".format(args.resume))
93
+ checkpoint = torch.load(args.resume)
94
+ args.start_epoch = checkpoint['epoch']
95
+ lowest_loss = checkpoint['lowest_loss']
96
+ model.module.load_state_dict(checkpoint['state_dict'], strict=True)
97
+
98
+ print("=> loaded checkpoint '{}' (epoch {})".format(args.resume, checkpoint['epoch']))
99
+ except:
100
+ print("=> no checkpoint found at '{}'".format(args.resume))
101
+
102
+ cudnn.benchmark = True
103
+ print(' Total params: %.2fM' % (sum(p.numel() for p in optimizer.param_groups[0]['params']) / 1000000.0))
104
+ my_collate_func = my_collate_diff if args.mode == 'diffusion' else my_collate
105
+ if world_size > 1:
106
+ if not args.split:
107
+ train_dataset = shared_dict['train_dataset']
108
+ train_sampler = DistributedSampler(train_dataset, num_replicas=world_size, rank=rank)
109
+ train_loader = DataLoader(train_dataset, batch_size=args.train_batch, sampler=train_sampler, collate_fn= my_collate_func)
110
+ else:
111
+ train_dataset = AnymateDataset(name=args.trainset + f'_{rank}', root=args.root) #should changed to dpp version
112
+ train_loader = DataLoader(train_dataset, batch_size=args.train_batch, shuffle=True, collate_fn= my_collate_func)
113
+ else:
114
+ train_dataset = AnymateDataset(name=args.trainset, root=args.root)
115
+ train_loader = DataLoader(train_dataset, batch_size=args.train_batch, shuffle=True, collate_fn= my_collate_func)
116
+
117
+ if rank == 0:
118
+ test_loader = DataLoader(AnymateDataset(name=args.testset, root=args.root), batch_size=args.test_batch, shuffle=False, collate_fn= my_collate_func )
119
+
120
+ if not args.schedule:
121
+ args.schedule = [args.epochs//2]
122
+ scheduler = torch.optim.lr_scheduler.MultiStepLR(optimizer, args.schedule, gamma=args.gamma)
123
+ # step the scheduler to the start epoch
124
+ for _ in range(args.start_epoch):
125
+ scheduler.step()
126
+ if rank == 0:
127
+ logger = SummaryWriter(log_dir=args.logdir)
128
+ print('start ')
129
+ print('test_frequency', args.test_freq)
130
+ print('start from epoch', args.start_epoch)
131
+ # start training
132
+ for epoch in range(args.start_epoch, args.epochs):
133
+ test_dict = None
134
+ is_best = False
135
+ lr = scheduler.get_last_lr()
136
+ if rank == 0:
137
+ print('\nEpoch: %d | LR: %.8f' % (epoch + 1, lr[0]))
138
+ train_loss, grad_norm = train(train_loader, model, optimizer, args)
139
+ if rank == 0 and (epoch == 0 or (epoch+1)%args.test_freq== 0):
140
+ print('Testing epoch', epoch+1)
141
+ test_dict = test(test_loader, model, args, world_size=world_size)
142
+
143
+
144
+ scheduler.step()
145
+ if rank == 0:
146
+ print('Epoch{:d}. train_loss: {:.6f}.'.format(epoch + 1, train_loss))
147
+ print('Epoch{:d}. grad_norm: {:.6f}.'.format(epoch + 1, grad_norm))
148
+ info = {'train_loss': train_loss, 'grad_norm': grad_norm, 'lr': lr[0]}
149
+ # print('Epoch{:d}. val_loss: {:.6f}.'.format(epoch + 1, val_loss))
150
+ if test_dict is not None:
151
+ for key, value in test_dict.items():
152
+ print('Epoch{:d}. {:s}: {:.6f}.'.format(epoch + 1, key, value))
153
+
154
+ test_loss = test_dict['test loss'] if not args.mode == 'diffusion' else test_dict['chamfer']
155
+ is_best = test_loss < lowest_loss
156
+ lowest_loss = min(test_loss, lowest_loss)
157
+ for key, value in test_dict.items():
158
+ info[key] = value
159
+
160
+ for tag, value in info.items():
161
+ logger.add_scalar(tag, value, epoch+1)
162
+ save_dict = {'epoch': epoch + 1, 'state_dict': model.module.state_dict(), 'lowest_loss': lowest_loss, 'optimizer': optimizer.state_dict(), 'model_config': model_config}
163
+ save_checkpoint(save_dict, is_best=is_best, checkpoint=args.checkpoint, snapshot=args.epochs//20)
164
+
165
+ def get_criterion(args):
166
+ if args.loss == 'cos':
167
+ criterion = cos_loss
168
+ elif args.loss == 'ce':
169
+ criterion = cross_entropy_with_probs_batch
170
+ elif args.loss == 'cos_clamp':
171
+ criterion = cos_loss_clamp
172
+ else:
173
+ criterion = chamfer_distance_with_average
174
+ return criterion
175
+
176
+ def get_train_loss(model, data, args):
177
+ criterion = get_criterion(args)
178
+ loss = 0.0
179
+ if args.mode == 'skin':
180
+ y_pred, idx = model(data, downsample=1024)
181
+ y_pred = torch.softmax(y_pred, dim=-1)
182
+ y = data['skins'].to(args.device)
183
+ y = y[:, idx]
184
+ loss = criterion(y_pred, y)
185
+
186
+ elif args.mode == 'conn':
187
+ y_pred = model(data, args.device)
188
+ y_pred = torch.softmax(y_pred, dim=-1)
189
+ y = data['conns'].to(args.device)
190
+ y = y[:, :y_pred.shape[1], :y_pred.shape[1]].float()
191
+ loss = criterion(y_pred, y)
192
+
193
+ elif args.mode == 'joints': # joints mode
194
+ if args.decoder == 'transformer_latent':
195
+ y_pred = model(data, args.device)
196
+ joints_gt = data['joints'].to(args.device)
197
+ loss = 0.0
198
+ for i in range(joints_gt.shape[0]):
199
+ joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
200
+ loss += criterion(y_pred[i:i+1], joints_gt_i.unsqueeze(0))
201
+ loss /= joints_gt.shape[0]
202
+
203
+ elif args.decoder == 'triplane' or args.decoder == 'implicit_transformer':
204
+ criterion = torch.nn.BCEWithLogitsLoss()
205
+ y_pred = model(data, args.device, downsample=True)
206
+ joints_gt = data['joints'].to(args.device)
207
+ for i in range(joints_gt.shape[0]):
208
+ joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
209
+ vol = get_co(data['vox'][i])
210
+ if data['vox'][i].shape[0] > 50000:
211
+ vol = vol[y_pred[i][1]]
212
+ gt = get_gt(vol.to(args.device), joints_gt_i)
213
+ loss += criterion(y_pred[i][0].squeeze(-1).unsqueeze(0), gt.unsqueeze(0))
214
+ else:
215
+ gt = get_gt(vol.to(args.device), joints_gt_i)
216
+ loss += criterion(y_pred[i].squeeze(-1).unsqueeze(0), gt.unsqueeze(0))
217
+ loss /= joints_gt.shape[0]
218
+
219
+ elif args.mode == 'diffusion':
220
+ noise_scheduler = DDIMScheduler(num_train_timesteps=args.num_train_step)
221
+
222
+ samples = data['joints_repeat'].to(model.device).float()
223
+ #use 256 input joints
224
+ samples = samples[...,:args.num_training_points,:]
225
+
226
+ samples = samples.to(model.device)
227
+ noise = torch.randn(samples.shape, device=samples.device)
228
+ assert samples.device == noise.device
229
+ bs = samples.shape[0]
230
+
231
+ # Sample a random timestep for each image
232
+ timesteps = torch.randint(
233
+ 0, noise_scheduler.config.num_train_timesteps, (bs,), device=samples.device,
234
+ dtype=torch.int64
235
+ )
236
+
237
+ noisy_joints = noise_scheduler.add_noise(samples, noise, timesteps)
238
+ noisy_joints = noisy_joints.to(model.device)
239
+ noisy_joints = noisy_joints.permute(0, 2, 1)
240
+
241
+ noise_pred = model(data, noisy_joints=noisy_joints, timesteps = timesteps)
242
+ noise_pred = noise_pred.permute(0, 2, 1)
243
+ loss = F.mse_loss(noise_pred, noise)
244
+
245
+ return loss
246
+
247
+ def train(train_loader, model, optimizer, args):
248
+ if not args.finetune:
249
+ model.train()
250
+ model.module.encoder.eval()
251
+ else:
252
+ model.train()
253
+ loss_meter = AverageMeter()
254
+ grad_norm_meter = AverageMeter()
255
+
256
+ for data in tqdm(train_loader):
257
+ loss = get_train_loss(model, data, args)
258
+ optimizer.zero_grad()
259
+ loss.backward()
260
+ grad_norm = 0
261
+
262
+ for p in optimizer.param_groups[0]['params']:
263
+ grad_norm += p.grad.data.norm(2).item()
264
+ grad_norm_meter.update(grad_norm)
265
+ optimizer.step()
266
+ loss_meter.update(loss.item())
267
+
268
+ return loss_meter.avg, grad_norm_meter.avg
269
+
270
+ def test(test_loader, model, args, world_size=1):
271
+ model.eval()
272
+ assert args.mode in ['skin', 'joints', 'conn', 'diffusion'], 'mode should be choose from [skin, joints, conn, diffusion], got {}'.format(args.mode)
273
+
274
+ if args.mode == 'skin' or args.mode == 'conn':
275
+ loss_meter = AverageMeter()
276
+ cos_sim_meter = AverageMeter()
277
+ cos_clamp_meter = AverageMeter()
278
+ for i, data in enumerate(tqdm(test_loader)):
279
+ if world_size > 1 and i > 1000:
280
+ break
281
+ with torch.no_grad():
282
+ y_pred = model(data, args.device)
283
+ y_pred = torch.softmax(y_pred, dim=-1)
284
+
285
+ if args.mode == 'skin':
286
+ y = data['skins'].to(args.device)
287
+ elif args.mode == 'conn':
288
+ y = data['conns'].to(args.device)
289
+ y = y[:, :y_pred.shape[1], :y_pred.shape[1]].float()
290
+
291
+ loss = 0.0
292
+ loss = cross_entropy_with_probs_batch(y_pred, y)
293
+ loss_meter.update(loss.item())
294
+ cos_sim = cos_loss(y_pred, y)
295
+ cos_sim_meter.update(cos_sim.mean().item()) # 1 - loss.item()
296
+ cos_clamp = cos_loss_clamp(y_pred, y)
297
+ cos_clamp_meter.update(cos_clamp.mean().item())
298
+
299
+ loss_dict = {'test loss': loss_meter.avg, 'cos_sim': cos_sim_meter.avg, 'cos_clamp': cos_clamp_meter.avg}
300
+ # get the loss of the joints prediction
301
+ elif args.mode == 'joints':
302
+ if args.decoder == 'transformer_latent':
303
+ loss_meter = AverageMeter()
304
+ emd_meter = AverageMeter()
305
+ for i, data in tqdm(enumerate(test_loader)):
306
+ if world_size > 1 and i > 1000:
307
+ break
308
+ with torch.no_grad():
309
+ y_pred = model(data, args.device)
310
+ joints_gt = data['joints'].to(args.device)
311
+
312
+ loss = 0.0
313
+ emd = 0.0
314
+ for i in range(joints_gt.shape[0]):
315
+ joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
316
+ y_pred_i = y_pred[i]
317
+
318
+ y_pred_i = y_pred[i].detach().cpu().numpy()
319
+ clustering = DBSCAN(eps=0.03, min_samples=1).fit(y_pred_i) # Consider add eps and min_samples as arguments
320
+ cluster_centers = []
321
+ for cluster in set(clustering.labels_):
322
+ cluster_centers.append(y_pred_i[clustering.labels_ == cluster].mean(axis=0))
323
+ y_pred_i = torch.from_numpy(np.array(cluster_centers)).to(args.device)
324
+
325
+ if y_pred_i.shape[0] < 2:
326
+ print(data['name'][i] + ' has less than 2 points')
327
+ continue
328
+ loss += chamfer_distance_with_average(y_pred_i.unsqueeze(0), joints_gt_i.unsqueeze(0))
329
+ emd_i, pi = pcu.earth_movers_distance(y_pred_i.cpu().numpy().astype(np.float64), joints_gt_i.cpu().numpy().astype(np.float64))
330
+ emd += emd_i
331
+ if loss == 0 or emd == 0:
332
+ continue
333
+ loss /= joints_gt.shape[0]
334
+ loss_meter.update(loss.item())
335
+ emd_meter.update(emd)
336
+ loss_dict = {'test loss': loss_meter.avg, 'emd': emd_meter.avg}
337
+
338
+ elif args.decoder == 'triplane' or args.decoder == 'implicit_transformer':
339
+ loss_meter = AverageMeter()
340
+ emd_meter = AverageMeter()
341
+ chamfer_meter = AverageMeter()
342
+ criterion = torch.nn.BCEWithLogitsLoss()
343
+ for data in tqdm(test_loader):
344
+ with torch.no_grad():
345
+ y_pred = model(data, args.device)
346
+ joints_gt = data['joints'].to(args.device)
347
+ loss = 0.0
348
+ emd = 0.0
349
+ chamfer = 0.0
350
+ for i in range(joints_gt.shape[0]):
351
+ joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
352
+ vol = get_co(data['vox'][i])
353
+ gt = get_gt(vol.to(args.device), joints_gt_i)
354
+ loss += criterion(y_pred[i].squeeze(-1).unsqueeze(0), gt.unsqueeze(0))
355
+ key_points = extract_keypoints(y_pred[i].cpu(), data['vox'][i])
356
+ if len(key_points) < 2:
357
+ continue
358
+ key_points = key_points / 32 - 1
359
+ chamfer += chamfer_distance_with_average(torch.from_numpy(key_points).unsqueeze(0).to(joints_gt_i.device), joints_gt_i.unsqueeze(0))
360
+ emd_i, _ = pcu.earth_movers_distance(key_points.astype(np.float64), joints_gt_i.cpu().numpy().astype(np.float64))
361
+ emd += emd_i
362
+ if loss == 0 or emd == 0 or chamfer == 0:
363
+ continue
364
+ loss /= joints_gt.shape[0]
365
+ loss_meter.update(loss.item())
366
+ emd_meter.update(emd)
367
+ chamfer_meter.update(chamfer.item())
368
+ loss_dict = {'test loss': loss_meter.avg, 'emd': emd_meter.avg, 'chamfer': chamfer_meter.avg}
369
+
370
+ elif args.mode == 'diffusion':
371
+ loss_meter = AverageMeter()
372
+ emd_meter = AverageMeter()
373
+ chamfer_meter = AverageMeter()
374
+ generator=torch.Generator(device='cpu').manual_seed(args.seed+1)
375
+ scheduler = DDIMScheduler(num_train_timesteps=args.num_train_step)
376
+ scheduler.set_timesteps(args.num_train_step)
377
+ points_shape = [args.test_batch, args.num_training_points, 3]
378
+ for data in tqdm(test_loader):
379
+ joints_gt = data['joints'].to(dtype=torch.float64)
380
+ points_noise = randn_tensor(points_shape, generator=generator)
381
+ points = points_noise.permute(0, 2, 1).to(model.device)
382
+ for t in scheduler.timesteps:
383
+ with torch.no_grad():
384
+ time_steps = torch.ones(args.test_batch, 1, dtype=torch.long) * t
385
+ time_steps = time_steps.to(model.device)
386
+ model_output = model(data, noisy_joints=points, timesteps = time_steps)
387
+
388
+ points = scheduler.step(model_output, t, points, generator=generator).prev_sample
389
+ points = points.permute(0, 2, 1).cpu()
390
+
391
+ chamfer_sum = 0.0
392
+ emd_sum = 0.0
393
+
394
+ for i in range(args.test_batch):
395
+ joints_gt_i = joints_gt[i,:data['joints_num'][i], :3]
396
+ points_i = points[i]
397
+ points_i = points_i.reshape( -1, 3)
398
+ emd, p = pcu.earth_movers_distance(points_i.cpu().numpy(),joints_gt_i[:,:3].cpu().numpy())
399
+ emd_sum += emd
400
+ chamfer_sum += chamfer_distance_with_average(points_i.unsqueeze(0),joints_gt_i[:,:3].unsqueeze(0))
401
+
402
+ emd_meter.update(emd_sum)
403
+ chamfer_meter.update(chamfer_sum.item())
404
+ loss_dict = {'chamfer': chamfer_meter.avg, 'emd': emd_meter.avg}
405
+
406
+ return loss_dict
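(train_model() above is written to be spawned once per GPU: it calls ddp_setup(rank, world_size, port), wraps the model in DDP, and optionally shares a preloaded dataset through shared_dict. A minimal launcher sketch follows; the config path is a placeholder and `args` stands for the project's training argument namespace, which must carry every field train_model() reads, e.g. mode, trainset, root, epochs.)

# Hypothetical multi-GPU launcher; config path and `args` are placeholders.
import yaml
import torch
import torch.multiprocessing as mp
from Anymate.utils.train_utils import train_model

def launch(args, config_path='path/to/config.yaml'):
    with open(config_path) as f:
        config = yaml.safe_load(f)  # must contain the 'model' and 'optimizer' sections read above
    world_size = torch.cuda.device_count()
    shared_dict = {}                # shared dataset cache, used when args.split is False
    mp.spawn(train_model, args=(world_size, config, args, shared_dict), nprocs=world_size)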
Anymate/utils/ui_utils.py ADDED
@@ -0,0 +1,284 @@
1
+ import trimesh
2
+ import numpy as np
3
+ import torch
4
+ import os
5
+ import matplotlib.pyplot as plt
6
+ import gradio as gr
7
+ import time
8
+ bone_colors = plt.get_cmap('tab10')
9
+
10
+ from Anymate.utils.utils import load_checkpoint, get_joint, get_connectivity, get_skinning
11
+ from Anymate.utils.dataset_utils import obj2mesh
12
+ from Anymate.args import anymate_args
13
+ # from Anymate.utils.render_utils import empty, add_co, add_mesh, add_joint, add_conn, add_skin, setup_armature
14
+
15
+ def visualize_results(mesh_file=None, joints=None, conns=None, skins=None):
16
+ # Create a scene with both original and processed meshes
17
+ scene = trimesh.Scene()
18
+ vis_file = mesh_file.replace('object.obj', 'vis.glb')
19
+
20
+ if mesh_file is not None:
21
+ # Load the original mesh (in blue) with transparency
22
+ # original_mesh = trimesh.load(mesh_file)
23
+ original_mesh = obj2mesh(mesh_file)
24
+ if skins is not None:
25
+ # pdb.set_trace()
26
+ # Get per-vertex colors based on skinning weights
27
+ vertex_colors = np.zeros((len(original_mesh.vertices), 4))
28
+
29
+ # Convert skinning weights to numpy if needed
30
+ if isinstance(skins, torch.Tensor):
31
+ skins = skins.cpu().numpy()
32
+
33
+ # For each bone, blend colors based on skinning weights
34
+ for bone_idx in range(skins.shape[1]):
35
+ bone_color = np.array(bone_colors(bone_idx % 10)) # Get base color for this bone
36
+ weights = skins[:, bone_idx]
37
+ vertex_colors += np.outer(weights, bone_color) # Blend weighted colors
38
+
39
+ # Normalize and clip colors
40
+ vertex_colors = np.clip(vertex_colors, 0, 1)
41
+
42
+ # Convert to vertex colors and set alpha
43
+ vertex_colors = (vertex_colors * 255).astype(np.uint8)
44
+ vertex_colors[:, 3] = 255 # Set alpha to 255 (fully opaque)
45
+ # print(vertex_colors.shape)
46
+ # print(vertex_colors.max(axis=0), vertex_colors.min(axis=0), vertex_colors.mean(axis=0))
47
+
48
+ # Apply colors directly to vertices
49
+ original_mesh.visual.vertex_colors = vertex_colors
50
+
51
+ # face_colors = np.zeros((len(original_mesh.faces), 4))
52
+
53
+ # processed_mesh = trimesh.load(mesh_file)
54
+ processed_mesh = obj2mesh(mesh_file)
55
+ # Assign vertex colors from original_mesh to processed_mesh
56
+ # Since they might have different number of vertices, we need to find closest vertices
57
+
58
+ # Get vertices from both meshes
59
+ orig_vertices = original_mesh.vertices
60
+ proc_vertices = processed_mesh.vertices
61
+
62
+ # For each vertex in processed_mesh, find the closest vertex in original_mesh
63
+ closest_indices = []
64
+ for proc_vertex in proc_vertices:
65
+ # Calculate distances to all original vertices
66
+ distances = np.linalg.norm(orig_vertices - proc_vertex, axis=1)
67
+ # Find index of closest vertex
68
+ closest_idx = np.argmin(distances)
69
+ closest_indices.append(closest_idx)
70
+
71
+ proc_vertex_colors = original_mesh.visual.vertex_colors[closest_indices]
72
+ processed_mesh.visual.vertex_colors = proc_vertex_colors
73
+ original_mesh = processed_mesh
74
+
75
+ else:
76
+ original_mesh.visual.face_colors = [255, 255, 255, 100] # White with alpha=100 for transparency
77
+ scene.add_geometry(original_mesh)
78
+
79
+ if joints is not None:
80
+ # create a sphere for each joint
81
+ for position in joints:
82
+ sphere = trimesh.primitives.Sphere(radius=0.02)
83
+ sphere.visual.face_colors = [255, 0, 0, 255] # Red, fully opaque
84
+ sphere.apply_translation(position.cpu().numpy())
85
+ scene.add_geometry(sphere)
86
+
87
+ if conns is not None:
88
+ # create a line for each connectivity
89
+ for i, conn in enumerate(conns):
90
+ if i == conn:
91
+ continue
92
+ # Create cylinder between joints
93
+ points = [joints[i].cpu().numpy(), joints[conn].cpu().numpy()]
94
+ direction = points[1] - points[0]
95
+ height = np.linalg.norm(direction)
96
+ cylinder = trimesh.primitives.Cylinder(radius=0.01, height=height)
97
+
98
+ # Calculate rotation matrix to align cylinder with direction
99
+ direction = direction / height # Normalize direction vector
100
+ up_vector = np.array([0, 0, 1])
101
+ rotation_matrix = trimesh.geometry.align_vectors(up_vector, direction)
102
+
103
+ # Apply rotation and translation to cylinder
104
+ cylinder.apply_transform(rotation_matrix)
105
+ cylinder.apply_translation(points[0] + direction * height/2)
106
+
107
+ cylinder.visual.face_colors = [0, 0, 255, 255] # Blue
108
+ scene.add_geometry(cylinder)
109
+
110
+ # Export the scene
111
+ scene.export(vis_file)
112
+ return vis_file
113
+
114
+
115
+ def process_mesh_to_pc(obj_path, sample_num = 8192, save_path = None):
116
+ # mesh_list : list of trimesh
117
+ try :
118
+ mesh = trimesh.load_mesh(obj_path)
119
+
120
+ points, face_idx = mesh.sample(sample_num, return_index=True)
121
+ normals = mesh.face_normals[face_idx]
122
+
123
+ pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16)
124
+
125
+
126
+ if save_path is not None:
127
+ np.save(save_path, pc_normal)
128
+
129
+ return pc_normal
130
+ except Exception as e:
131
+ print(f"Error: {obj_path} {e}")
132
+ return None
133
+
134
+
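(process_mesh_to_pc() above returns an (N, 6) float16 array of surface samples with normals, which is what the downstream models consume. A one-line usage sketch, with a placeholder path:)

# Hypothetical call; the path is a placeholder.
pc_normal = process_mesh_to_pc('path/to/object.obj', sample_num=8192)  # shape (8192, 6): xyz + normal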
135
+ def normalize_mesh(mesh):
136
+ # Check if input is a scene with multiple meshes
137
+ if isinstance(mesh, trimesh.Scene):
138
+ # Combine all meshes in the scene into a single mesh
139
+ meshes = []
140
+ for geometry in mesh.geometry.values():
141
+ if isinstance(geometry, trimesh.Trimesh):
142
+ # Transform mesh to scene coordinates
143
+ transform = mesh.graph[mesh.graph.nodes_geometry[0]][0]
144
+ geometry.apply_transform(transform)
145
+ meshes.append(geometry)
146
+
147
+ # Combine all meshes
148
+ mesh = trimesh.util.concatenate(meshes)
149
+
150
+ # Get vertices and compute bounding box
151
+ vertices = mesh.vertices
152
+ bbox_min = vertices.min(axis=0)
153
+ bbox_max = vertices.max(axis=0)
154
+
155
+ # Find center and scale
156
+ center = (bbox_min + bbox_max) * 0.5
157
+ scale = 2.0 / (bbox_max - bbox_min).max()
158
+
159
+ # Center and scale vertices
160
+ vertices = (vertices - center) * scale
161
+
162
+ # Create new mesh with normalized vertices
163
+ normalized_mesh = trimesh.Trimesh(vertices=vertices,
164
+ faces=mesh.faces,
165
+ face_normals=mesh.face_normals,
166
+ vertex_normals=mesh.vertex_normals,
167
+ process=False)
168
+
169
+ # # Copy texture from original mesh if it exists
170
+ # if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'material'):
171
+ # print("copy material")
172
+ # normalized_mesh.visual.material = mesh.visual.material
173
+ # if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'texture'):
174
+ # print("copy texture")
175
+ # normalized_mesh.visual.texture = mesh.visual.texture
176
+ # if hasattr(mesh, 'visual') and hasattr(mesh.visual, 'uv'):
177
+ # print("copy uv")
178
+ # normalized_mesh.visual.uv = mesh.visual.uv
179
+
180
+ return normalized_mesh
181
+
182
+
183
+ def vis_joint(normalized_mesh_file, joints):
184
+ if normalized_mesh_file is None or joints is None:
185
+ return None, None
186
+ vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints)
187
+ return vis_file, vis_file
188
+
189
+ def vis_connectivity(normalized_mesh_file, joints, conns):
190
+ if normalized_mesh_file is None or joints is None or conns is None:
191
+ return None, None
192
+ vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints, conns=conns)
193
+ return vis_file, vis_file
194
+
195
+ def vis_skinning(normalized_mesh_file, joints, conns, skins):
196
+ if normalized_mesh_file is None or joints is None or conns is None or skins is None:
197
+ return None, None
198
+     vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints, conns=conns, skins=skins)
+     return vis_file, vis_file
+
+ def prepare_blender_file(normalized_mesh_file):
+     if normalized_mesh_file is None:
+         return None
+
+     if not os.path.exists(normalized_mesh_file) or not os.path.exists(normalized_mesh_file.replace('object.obj', 'joints.pt')) or not os.path.exists(normalized_mesh_file.replace('object.obj', 'conns.pt')) or not os.path.exists(normalized_mesh_file.replace('object.obj', 'skins.pt')):
+         return None
+
+     folder = normalized_mesh_file.replace('object.obj', '')
+     abs_folder = os.path.abspath(folder)
+     os.system(f"python Render.py --path {abs_folder}")
+
+     blender_file = os.path.join(folder, 'blender_output.blend')
+     # Render.py runs in a separate process, so wait until the .blend file appears.
+     while not os.path.exists(blender_file):
+         time.sleep(1)
+
+     return blender_file
+
+
+ def process_input(mesh_file):
+     """
+     Handle a new input mesh and initialize the visualization.
+
+     Args:
+         mesh_file: Path to the input mesh file.
+
+     Returns:
+         A tuple (normalized_mesh_file, vis_file, vis_file, None, pc, None, None, None),
+         where pc is the sampled point cloud consumed by the joint, connectivity, and skinning models.
+     """
+
+     # For now just visualize the input mesh
+     if mesh_file is None:
+         return None, None, None, None, None, None, None, None
+
+     # make folder for tmp files
+     os.makedirs(f"Anymate/tmp/{mesh_file.split('/')[-1].replace('.obj', '')}", exist_ok=True)
+
+     normalized_mesh = normalize_mesh(obj2mesh(mesh_file))
+     normalized_mesh_file = f"Anymate/tmp/{mesh_file.split('/')[-1].replace('.obj', '')}/object.obj"
+     normalized_mesh.export(normalized_mesh_file)
+
+     vis_file = visualize_results(mesh_file=normalized_mesh_file)
+     pc = process_mesh_to_pc(normalized_mesh_file)
+     pc = torch.from_numpy(pc).to(anymate_args.device).to(torch.float32)
+
+     return normalized_mesh_file, vis_file, vis_file, None, pc, None, None, None
+
+
+ def get_model(checkpoint):
+     model = load_checkpoint(checkpoint, anymate_args.device, anymate_args.num_joints)
+     return model, True
+
+ def get_result_joint(mesh_file, model, pc, eps=0.03, min_samples=1):
+     return get_joint(pc, model, device=anymate_args.device, save=mesh_file.replace('object.obj', 'joints.pt'), eps=eps, min_samples=min_samples)
+
+ def get_result_connectivity(mesh_file, model, pc, joints):
+     return get_connectivity(pc, joints, model, device=anymate_args.device, save=mesh_file.replace('object.obj', 'conns.pt'))
+
+ def get_result_skinning(mesh_file, model, pc, joints, conns):
+     mesh = obj2mesh(mesh_file)
+     vertices = torch.from_numpy(mesh.vertices).to(anymate_args.device).to(torch.float32)
+     vertex_normals = torch.from_numpy(mesh.vertex_normals).to(anymate_args.device).to(torch.float32)
+     vertices = torch.cat([vertices, vertex_normals], dim=-1)
+     return get_skinning(pc, joints, conns, model, vertices=vertices, device=anymate_args.device, save=mesh_file.replace('object.obj', 'skins.pt'))
+
+ def get_all_models(checkpoint_joint, checkpoint_conn, checkpoint_skin):
+     model_joint = load_checkpoint(checkpoint_joint, anymate_args.device, anymate_args.num_joints)
+     model_connectivity = load_checkpoint(checkpoint_conn, anymate_args.device, anymate_args.num_joints)
+     model_skin = load_checkpoint(checkpoint_skin, anymate_args.device, anymate_args.num_joints)
+     return model_joint, model_connectivity, model_skin, True, True, True
+
+ def get_all_results(mesh_file, model_joint, model_connectivity, model_skin, pc, eps=0.03, min_samples=1):
+     joints = get_result_joint(mesh_file, model_joint, pc, eps=eps, min_samples=min_samples)
+     conns = get_result_connectivity(mesh_file, model_connectivity, pc, joints)
+     skins = get_result_skinning(mesh_file, model_skin, pc, joints, conns)
+     return joints, conns, skins
+
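Taken together, the helpers above form the UI's full inference pipeline. The sketch below chains them end to end; it is a minimal illustration, assuming these helpers live in `Anymate/utils/ui_utils.py`, and every path in it (mesh and checkpoints) is a hypothetical placeholder rather than a real asset.

```python
# Minimal sketch of the UI pipeline (all paths below are hypothetical placeholders).
from Anymate.utils.ui_utils import (process_input, get_all_models,
                                    get_all_results, prepare_blender_file)

mesh_file = "examples/chair.obj"  # placeholder input mesh

# Normalize the mesh into Anymate/tmp/<name>/object.obj and sample its point cloud.
normalized_mesh_file, vis_file, _, _, pc, _, _, _ = process_input(mesh_file)

# Load the joint, connectivity, and skinning predictors from their checkpoints.
model_joint, model_conn, model_skin, *_ = get_all_models(
    "Anymate/checkpoints/joint.pth.tar",   # placeholder checkpoint paths
    "Anymate/checkpoints/conn.pth.tar",
    "Anymate/checkpoints/skin.pth.tar",
)

# Predict joints, parent connectivity, and skinning weights, then bake a .blend file.
joints, conns, skins = get_all_results(normalized_mesh_file, model_joint,
                                        model_conn, model_skin, pc)
blender_file = prepare_blender_file(normalized_mesh_file)
print("Rigged scene written to", blender_file)
```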
Anymate/utils/ui_utils_bpy.py ADDED
@@ -0,0 +1,134 @@
1
+ import trimesh
2
+ import numpy as np
3
+ import torch
4
+
5
+ from Anymate.utils.utils import load_checkpoint, get_joints, get_connectivity
6
+ from Anymate.args import anymate_args
7
+ from Anymate.utils.render_utils import empty, add_co, add_mesh, add_joints, add_conn, add_skin, setup_armature, save_scene
8
+
9
+ def visualize_results(mesh_file=None, joints=None, connectivity=None, skinning=None):
10
+
11
+ import bpy
12
+ # Create a scene with both original and processed meshes
13
+ vis_file = "Anymate/tmp/vis_scene.glb"
14
+
16
+ # empty()
17
+ bpy.ops.wm.read_homefile(use_empty=True)
18
+
19
+ if mesh_file is not None:
20
+ # add_mesh(mesh_file)
21
+ bpy.ops.wm.obj_import(filepath=mesh_file)
22
+
23
+ if joints is not None:
24
+ add_joints(joints)
25
+
26
+ if connectivity is not None:
27
+ add_conn(connectivity, joints)
28
+
29
+ if skinning is not None:
30
+ add_skin(mesh_file, skinning)
31
+
32
+ # setup_armature()
33
+ # save_scene(vis_file)
34
+ bpy.ops.wm.save_as_mainfile(filepath=vis_file)
35
+ return vis_file
36
+
37
+
38
+ def process_mesh_to_pc(obj_path, sample_num = 8192, save_path = None):
39
+ # mesh_list : list of trimesh
40
+ try :
41
+ mesh = trimesh.load_mesh(obj_path)
42
+
43
+ points, face_idx = mesh.sample(sample_num, return_index=True)
44
+ normals = mesh.face_normals[face_idx]
45
+
46
+ pc_normal = np.concatenate([points, normals], axis=-1, dtype=np.float16)
47
+
48
+
49
+ if save_path is not None:
50
+ np.save(save_path, pc_normal)
51
+
52
+ return pc_normal
53
+ except Exception as e:
54
+ print(f"Error: {obj_path} {e}")
55
+ return None
56
+
57
+
58
+ def normalize_mesh(mesh):
59
+ # Get vertices and compute bounding box
60
+ vertices = mesh.vertices
61
+ bbox_min = vertices.min(axis=0)
62
+ bbox_max = vertices.max(axis=0)
63
+
64
+ # Find center and scale
65
+ center = (bbox_min + bbox_max) * 0.5
66
+ scale = 2.0 / (bbox_max - bbox_min).max()
67
+
68
+ # Center and scale vertices
69
+ vertices = (vertices - center) * scale
70
+
71
+ # Create new mesh with normalized vertices
72
+ normalized_mesh = trimesh.Trimesh(vertices=vertices,
73
+ faces=mesh.faces,
74
+ face_normals=mesh.face_normals,
75
+ vertex_normals=mesh.vertex_normals)
76
+
77
+ return normalized_mesh
78
+
79
+
80
+ def vis_joint(normalized_mesh_file, joints):
81
+ vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints)
82
+ return vis_file
83
+
84
+ def vis_connectivity(normalized_mesh_file, joints, connectivity):
85
+ vis_file = visualize_results(mesh_file=normalized_mesh_file, joints=joints, connectivity=connectivity)
86
+ return vis_file
87
+
88
+ def vis_skinning(skinning):
89
+ vis_file = visualize_results(skinning=skinning)
90
+ return vis_file
91
+
92
+
93
+ def process_input(mesh_file):
94
+ """
95
+ Function to handle input changes and initialize visualization
96
+
97
+ Args:
98
+ mesh_file: Path to input mesh file
99
+ joint_checkpoint: Path to joint prediction checkpoint
100
+ conn_checkpoint: Path to connectivity prediction checkpoint
101
+ skin_checkpoint: Path to skinning prediction checkpoint
102
+
103
+ Returns:
104
+ vis_file: Path to visualization file
105
+ """
106
+
107
+ # For now just visualize the input mesh
108
+
109
+ normalized_mesh = normalize_mesh(trimesh.load(mesh_file))
110
+ normalized_mesh_file = "Anymate/tmp/normalized_mesh.obj"
111
+ normalized_mesh.export(normalized_mesh_file)
112
+ vis_file = visualize_results(mesh_file=normalized_mesh_file)
113
+ pc = process_mesh_to_pc(normalized_mesh_file)
114
+ pc = torch.from_numpy(pc).to(anymate_args.device).to(torch.float32)
115
+
116
+ print(pc.shape, pc.max(dim=0), pc.min(dim=0))
117
+
118
+ return normalized_mesh_file, vis_file, pc, None, None, None
119
+
120
+
121
+ def get_model(checkpoint):
122
+ model = load_checkpoint(checkpoint, anymate_args.device, anymate_args.num_joints)
123
+ return model, True
124
+
125
+ def get_result_joint(model, pc):
126
+ return get_joints(pc, model, anymate_args.device)
127
+
128
+ def get_result_connectivity(model, pc, joints):
129
+ return get_connectivity(pc, joints, model, anymate_args.device)
130
+
131
+ def get_result_skinning(model, pc):
132
+ with torch.no_grad():
133
+ skinning = model(pc)
134
+ return skinning
Anymate/utils/utils.py ADDED
@@ -0,0 +1,77 @@
+ import torch
+ from Anymate.model import EncoderDecoder
+ from sklearn.cluster import DBSCAN
+
+ def load_checkpoint(path, device, num_joints):
+     print(f"Loading model from {path}")
+     model_state = torch.load(path)
+     model_weights = model_state['state_dict']
+
+     try:
+         model_config = model_state['model_config']
+         model = EncoderDecoder(device=device, dtype=torch.float32, **model_config)
+         model.to(device)
+         model.load_state_dict(model_weights, strict=True)
+     except:
+         # Older checkpoints do not store a model_config; fall back to parsing the
+         # "<encoder>-<decoder>-..." pattern from the checkpoint filename.
+         encoder = path.split('/')[-1].split('.')[0].split('-')[0]
+         decoder = path.split('/')[-1].split('.')[0].split('-')[1]
+         model = EncoderDecoder(encoder=encoder, decoder=decoder, device=device, dtype=torch.float32, num_joints=num_joints)
+         model.to(device)
+         model.load_state_dict(model_weights, strict=True)
+
+     print(f"Loaded model from {path}")
+
+     return model
+
+ def get_joint(pc, model, device='cuda', save=None, vox=None, eps=0.03, min_samples=1):
+     model.eval()
+     data = {'points_cloud': pc.unsqueeze(0)}
+     if vox is not None:
+         data['vox'] = vox.unsqueeze(0)
+     with torch.no_grad():
+         model.decoder.inference_mode(eps=eps, min_samples=min_samples)
+         joints = model(data, device=device)
+         joints = torch.tensor(joints, dtype=torch.float32).to(device)
+
+     if save is not None:
+         torch.save(joints, save)
+
+     return joints
+
+ def get_connectivity(pc, joints, model, device='cuda', return_prob=False, save=None):
+     model.eval()
+     data = {'points_cloud': pc.unsqueeze(0), 'joints': joints.unsqueeze(0), 'joints_num': torch.tensor([joints.shape[0]]),
+             'joints_mask': torch.ones(joints.shape[0], device=device).unsqueeze(0)}
+     with torch.no_grad():
+         conns = model(data, device=device).softmax(dim=-1)
+         conns = conns.squeeze(0) if return_prob else torch.argmax(conns, dim=-1).squeeze(0)
+
+     if save is not None:
+         torch.save(conns, save)
+
+     return conns
+
+ def get_skinning(pc, joints, conns, model, vertices=None, bones=None, device='cuda', save=None):
+     model.eval()
+
+     if bones is None:
+         # Build one (parent joint, child joint) bone per non-root joint; conns[i] == i marks the root.
+         bones = []
+         for i in range(joints.shape[0]):
+             if conns[i] != i:
+                 bones.append(torch.cat((joints[conns[i]], joints[i]), dim=-1))
+         bones = torch.stack(bones, dim=0)
+
+     data = {'points_cloud': pc.unsqueeze(0), 'bones': bones.unsqueeze(0), 'bones_num': torch.tensor([bones.shape[0]]),
+             'bones_mask': torch.ones(bones.shape[0], device=device).unsqueeze(0)}
+
+     if vertices is not None:
+         data['vertices'] = vertices.unsqueeze(0)
+         model.decoder.inference = True
+
+     with torch.no_grad():
+         skins = model(data, device=device).softmax(dim=-1).squeeze(0)
+
+     if save is not None:
+         torch.save(skins, save)
+
+     return skins
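The predictions returned by these helpers are easiest to read with their shapes in mind. The snippet below is an illustrative interpretation, not part of the module; the shapes are assumptions inferred from the functions above.

```python
import torch

# Assumed shapes (inferred from the functions above, not guaranteed by them):
#   joints: (J, 3)  joint positions in the normalized mesh frame
#   conns:  (J,)    index of each joint's parent; conns[i] == i marks the root
#   skins:  (V, B)  per-vertex weights over the B bones

def bones_from_connectivity(joints: torch.Tensor, conns: torch.Tensor) -> torch.Tensor:
    """Rebuild the (B, 6) bone list the same way get_skinning does internally."""
    bones = [torch.cat((joints[conns[i]], joints[i]), dim=-1)
             for i in range(joints.shape[0]) if conns[i] != i]
    return torch.stack(bones, dim=0)

# Because get_skinning applies a softmax over the last dimension, each row of
# `skins` should sum to 1, and skins.argmax(dim=-1) picks the dominant bone per vertex.
```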
Anymate/utils/vol_utils.py ADDED
@@ -0,0 +1,135 @@
+ import numpy as np
+ import torch
+ from ThirdParty.michelangelo.graphics.primitives import generate_dense_grid_points
+ from sklearn.cluster import DBSCAN
+
+ def get_vol(bounds=(-0.5, 0.0, -0.5, 0.5, 1.0, 0.5), octree_depth=6):
+
+     bbox_min = np.array(bounds[0:3])
+     bbox_max = np.array(bounds[3:6])
+     bbox_size = bbox_max - bbox_min
+
+     xyz_samples, grid_size, length = generate_dense_grid_points(
+         bbox_min=bbox_min,
+         bbox_max=bbox_max,
+         octree_depth=octree_depth,
+         indexing="ij"
+     )
+     xyz_samples = torch.FloatTensor(xyz_samples)  # ((2^d)+1)^3
+
+     return xyz_samples
+
+ def get_co(vox, bounds=(-1.0, -1.0, -1.0, 1.0, 1.0, 1.0), dtype=torch.float32):
+
+     bbox_min = torch.tensor(bounds[0:3]).to(vox.device)
+     bbox_max = torch.tensor(bounds[3:6]).to(vox.device)
+     bbox_size = bbox_max - bbox_min
+
+     # ind = torch.argwhere(vox)
+     # ind = ind.to(dtype) / (vox.shape[0]) * bbox_size + bbox_min
+     ind = vox
+     ind = ind.to(dtype) / 64 * bbox_size + bbox_min
+
+     return ind.to(dtype)
+
+ def get_gt(vol, joints, octree_depth=6):
+     sigma = 2 / 2**octree_depth
+
+     dist = torch.cdist(vol, joints)
+     dist = dist.min(dim=1).values
+
+     gt = torch.exp(-dist**2 / 2 / sigma**2)
+
+     return gt
+
+ def project_onto_planes(planes, coordinates):
+     """
+     Projects 3D points onto a batch of 2D planes, returning 2D plane coordinates.
+
+     Takes plane axes of shape (n_planes, 3, 3) and coordinates of shape (N, M, 3);
+     returns projections of shape (N*n_planes, M, 2).
+     """
+     N, M, C = coordinates.shape
+     n_planes, _, _ = planes.shape
+     coordinates = coordinates.unsqueeze(1).expand(-1, n_planes, -1, -1).reshape(N*n_planes, M, 3)
+     inv_planes = torch.linalg.inv(planes).unsqueeze(0).expand(N, -1, -1, -1).reshape(N*n_planes, 3, 3)
+     projections = torch.bmm(coordinates, inv_planes)
+     return projections[..., :2]
+
+ def sample_from_planes(plane_axes, plane_features, coordinates, mode='bilinear', padding_mode='zeros', box_warp=None):
+     assert padding_mode == 'zeros'
+     N, n_planes, C, H, W = plane_features.shape
+     _, M, _ = coordinates.shape
+     plane_features = plane_features.view(N*n_planes, C, H, W)
+
+     # coordinates = (2/box_warp) * coordinates # TODO: add specific box bounds
+
+     projected_coordinates = project_onto_planes(plane_axes, coordinates).unsqueeze(1)
+     output_features = torch.nn.functional.grid_sample(plane_features, projected_coordinates.float(), mode=mode, padding_mode=padding_mode, align_corners=False).permute(0, 3, 2, 1).reshape(N, n_planes, M, C)
+     return output_features
+
+ def generate_planes():
+     """
+     Defines planes by the three vectors that form the "axes" of the
+     plane. Should work with arbitrary number of planes and planes of
+     arbitrary orientation.
+     """
+     return torch.tensor([[[1, 0, 0],
+                           [0, 1, 0],
+                           [0, 0, 1]],
+                          [[1, 0, 0],
+                           [0, 0, 1],
+                           [0, 1, 0]],
+                          [[0, 0, 1],
+                           [1, 0, 0],
+                           [0, 1, 0]]], dtype=torch.float32)
+
+ def extract_keypoints(y_pred, vox):
+
+     y_pred = y_pred.detach().cpu().numpy()
+     vox = vox.detach().cpu().numpy()
+     volume = np.zeros([64, 64, 64])
+     volume[...] = -100
+     volume[vox[:, 0], vox[:, 1], vox[:, 2]] = y_pred.squeeze(-1)
+
+     clusters = []
+     cluster_model = DBSCAN(eps=1.8, min_samples=1)
+
+     level = min((0.85 * y_pred.max() + 0.15 * y_pred.min()).item(), 0)
+     potential_points = np.argwhere(volume >= level)
+     clustering = cluster_model.fit(potential_points)
+     for cluster in set(clustering.labels_):
+         if cluster == -1:
+             print('got noise', len(potential_points[clustering.labels_ == cluster]))
+             continue
+         clusters.append(potential_points[clustering.labels_ == cluster])
+
+     while True:
+         if np.all(np.array([(len(cluster) < 10) for cluster in clusters])):
+             break
+         new_clusters = []
+         for points in clusters:
+             if len(points) < 10:
+                 new_clusters.append(points)
+                 continue
+
+             value = volume[points[:, 0], points[:, 1], points[:, 2]]
+
+             potential_points = points[value >= (0.1 * value.max() + 0.9 * value.min())]
+             clustering = cluster_model.fit(potential_points)
+             for cluster in set(clustering.labels_):
+                 if cluster == -1:
+                     print('got noise', len(potential_points[clustering.labels_ == cluster]))
+                     continue
+                 new_clusters.append(potential_points[clustering.labels_ == cluster])
+
+         clusters = new_clusters
+
+     key_points = np.array([cluster.mean(axis=0) for cluster in clusters])
+     key_points = key_points / 32 - 1
+
+     return key_points
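The tri-plane helpers above are easiest to follow as a shape check. The sketch below uses random placeholder tensors only to trace the expected shapes through `generate_planes` and `sample_from_planes`.

```python
import torch
from Anymate.utils.vol_utils import generate_planes, sample_from_planes

N, n_planes, C, H, W, M = 2, 3, 32, 64, 64, 100

plane_axes = generate_planes()                     # (3, 3, 3) axis-aligned plane bases
plane_features = torch.rand(N, n_planes, C, H, W)  # placeholder per-plane feature maps
coordinates = torch.rand(N, M, 3) * 2 - 1          # query points in [-1, 1]^3

features = sample_from_planes(plane_axes, plane_features, coordinates)
print(features.shape)  # (N, n_planes, M, C) -> torch.Size([2, 3, 100, 32])
```

Separately, `get_gt` turns each grid sample's distance to its nearest joint into a Gaussian target, exp(-d^2 / (2*sigma^2)) with sigma = 2 / 2**octree_depth, i.e. roughly one voxel of the sampling grid.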
Render.py ADDED
@@ -0,0 +1,17 @@
+ import argparse
+ import bpy
+ import mathutils
+ from Anymate.utils.render_utils import empty, setup_armature
+
+
+ def parse_args():
+     parser = argparse.ArgumentParser(description='Anymate rendering script')
+     parser.add_argument('--path', type=str, required=True, help='Path to the model file')
+     return parser.parse_args()
+
+ args = parse_args()
+
+ print(f"Starting to convert {args.path} to Blender format...")
+
+ empty()
+ setup_armature(args.path)
ThirdParty/PointLLM/.gitignore ADDED
@@ -0,0 +1,12 @@
+ __pycache__
+ *.egg-info
+ .vscode
+ checkpoints
+ outputs
+ wandb
+ anno_data
+ objaverse_data
+ modelnet40_data
+ *.zip
+ *.ipynb
+ serving_workdirs
ThirdParty/PointLLM/README.md ADDED
@@ -0,0 +1,353 @@
1
+ <br>
2
+ <p align="center">
3
+ <h1 align="center"><img src="assets/icon.png" align="center" width="6.5%"><strong>PointLLM: Empowering Large Language Models to Understand Point Clouds</strong></h1>
4
+ <p align="center">
5
+ <a href='https://runsenxu.com/' target='_blank'>Runsen Xu</a>&emsp;
6
+ <a href='https://guanfang12.github.io/' target='_blank'>Xiaolong Wang</a>&emsp;
7
+ <a href='https://tai-wang.github.io/' target='_blank'>Tai Wang</a>&emsp;
8
+ <a href='http://yilunchen.com/about' target='_blank'>Yilun Chen</a>&emsp;
9
+ <a href='https://oceanpang.github.io/' target='_blank'>Jiangmiao Pang*</a>&emsp;
10
+ <a href='http://dahua.site/' target='_blank'>Dahua Lin</a>&emsp;
11
+ <br>
12
+ The Chinese University of Hong Kong&emsp;Shanghai AI Laboratory&emsp;Zhejiang University
13
+ </p>
14
+ </p>
15
+
16
+ <p align="center">
17
+ <a href="http://arxiv.org/abs/2308.16911" target='_blank'>
18
+ <img src="https://img.shields.io/badge/arXiv-2308.16911-blue?">
19
+ </a>
20
+ <a href="https://arxiv.org/pdf/2308.16911.pdf" target='_blank'>
21
+ <img src="https://img.shields.io/badge/Paper-📖-blue?">
22
+ </a>
23
+ <a href="https://runsenxu.com/projects/PointLLM" target='_blank'>
24
+ <img src="https://img.shields.io/badge/Project-&#x1F680-blue">
25
+ </a>
26
+ <a href="http://101.230.144.196" target='_blank'>
27
+ <img src="https://img.shields.io/badge/Demo-&#x1f917-blue">
28
+ </a>
29
+ <a href="" target='_blank'>
30
+ <img src="https://visitor-badge.laobi.icu/badge?page_id=OpenRobotLab.pointllm&left_color=gray&right_color=blue">
31
+ </a>
32
+ <a href="https://openxlab.org.cn/apps/detail/openxlab-app/PointLLM" target='_blank'>
33
+ <img src="https://cdn-static.openxlab.org.cn/app-center/openxlab_app.svg">
34
+ </a>
35
+ </p>
36
+
37
+ ## 🏠 About
38
+ <!-- ![Teaser](assets/teaser.jpg) -->
39
+ <div style="text-align: center;">
40
+ <img src="assets/teaser.jpg" alt="Dialogue_Teaser" width=100% >
41
+ </div>
42
+ We introduce <b>PointLLM, a multi-modal large language model capable of understanding colored point clouds of objects.</b> It perceives object types, geometric structures, and appearance without concerns for ambiguous depth, occlusion, or viewpoint dependency. <b>We collect a novel dataset comprising 660K simple and 70K complex point-text instruction pairs</b> to enable a two-stage training strategy. To rigorously evaluate our model's perceptual abilities and its generalization capabilities, <b>we establish two benchmarks: Generative 3D Object Classification and 3D Object Captioning, assessed through three different evaluation methods.</b>
43
+
44
+ ## 🔥 News
45
+ - [2024-09-06] We have uploaded the camera-ready version of PointLLM for ECCV 2024, which includes clearer writing and additional experimental results. Please check the paper [here](https://arxiv.org/abs/2308.16911).
46
+ - [2024-07-01] PointLLM has been accepted by ECCV 2024 with all "strong-accept" recommendations. 🎉 We are looking for self-motivated students to conduct research on PointLLM. Please send an email to [email protected] with your CV if you are interested!
47
+ - [2023-12-29] We release the codes of our online Gradio demo.
48
+ - [2023-12-26] We release the codes for model evaluation, including ChatGPT/GPT-4 evaluation and traditional metric evaluation.
49
+ - [2023-12-08] We release the codes for training and PointLLM-v1.2. The online demo has also been upgraded to the v1.2 version. Please enjoy! &#x1F389;
50
+ - [2023-12-01] We have released an updated version of our paper (v2), which includes additional baseline comparisons, enhanced human-evaluation metrics, improved model performance (PointLLM-v1.2), and other refinements. Please check the updated version [here](https://arxiv.org/abs/2308.16911).
51
+ - [2023-10-18] We release our instruction-following data, including both the simple-description and complex instructions. Download [here](https://huggingface.co/datasets/RunsenXu/PointLLM).
52
+ - [2023-09-26] We release the inferencing codes with checkpoints as well as the Objaverse colored point cloud files we use. You can chat with PointLLM with your own machines.
53
+ - [2023-08-31] We release the [paper](http://arxiv.org/abs/2308.16911) of PointLLM and an online gradio [demo](http://101.230.144.196). Try it! &#x1F389;
54
+
55
+ <!-- contents with emoji -->
56
+ ## 📋 Contents
57
+ - [🤖 Online Demo](#-online-demo)
58
+ - [💬 Dialogue Examples](#-dialogue-examples)
59
+ - [🔍 Overview](#-overview)
60
+ - [📦 Training and Evaluation](#-training-and-evaluation)
61
+ - [📝 TODO List](#-todo-list)
62
+ - [🔗 Citation](#-citation)
63
+ - [📄 License](#-license)
64
+ - [📚 Related Work](#-related-work)
65
+ - [👏 Acknowledgements](#-acknowledgements)
66
+
67
+ ## 🤖 Online Demo
68
+ <b>PointLLM is online! Try it at [http://101.230.144.196](http://101.230.144.196) or at [OpenXLab/PointLLM](https://openxlab.org.cn/apps/detail/openxlab-app/PointLLM).</b>
69
+
70
+ You can chat with PointLLM about the models of the [Objaverse](https://objaverse.allenai.org) dataset or about your own point clouds!
71
+
72
+ Please do not hesitate to tell us if you have any feedback! 😃
73
+
74
+ ## 💬 Dialogue Examples
75
+ | Dialogue 1 | Dialogue 2| Dialogue 3 | Dialogue 4
76
+ | :-: | :-: | :-: | :-: |
77
+ | <img width="100%" src="assets/dialogue_1.jpg"> | <img width="100%" src="assets/dialogue_2.jpg"> | <img width="100%" src="assets/dialogue_3.jpg"> | <img width="100%" src="assets/dialogue_4.jpg"> |
78
+
79
+
80
+ ## 🔍 Overview
81
+
82
+ ### Model
83
+ <p align="center">
84
+ <img src="assets/model.jpg" align="center" width="100%">
85
+ </p>
86
+ The point encoder extracts features from the input point cloud and projects them to the latent space of the LLM backbone. The LLM backbone processes sequences of point tokens and text tokens, and generates the predicted tokens as the output.
87
+
88
+ ### Experiment Results
89
+ #### Quantitative Comparisons with baselines.
90
+ Please refer to our paper for more results.
91
+ <p align="center">
92
+ <img src="assets/cls_results.png" align="center" width="100%">
93
+ </p>
94
+ <p align="center">
95
+ <img src="assets/caption_results.png" align="center" width="100%">
96
+ </p>
97
+ <b>!!!Note: Traditional metrics such as BLEU-1, ROUGE-L, and METEOR tend to favor shorter responses and may not effectively capture semantic accuracy. For a detailed discussion on this, please refer to our paper. We suggest the community not solely rely on these metrics for evaluation.</b>
98
+
99
+ #### Qualitative Comparisons with baselines.
100
+ Please refer to our paper for more results.
101
+ <p align="center">
102
+ <img src="assets/qualitative_comparisons_v2.png" align="center" width="100%">
103
+ </p>
104
+
105
+ ## 📦 Training and Evaluation
106
+ ### Installation
107
+ We test our codes under the following environment:
108
+ - Ubuntu 20.04
109
+ - NVIDIA Driver: 515.65.01
110
+ - CUDA 11.7
111
+ - Python 3.10.13
112
+ - PyTorch 2.0.1
113
+ - Transformers 4.28.0.dev(transformers.git@cae78c46)
114
+
115
+ To start:
116
+ 1. Clone this repository.
117
+ ```bash
118
+ git clone [email protected]:OpenRobotLab/PointLLM.git
119
+ cd PointLLM
120
+ ```
121
+ 2. Install packages
122
+ ```bash
123
+ conda create -n pointllm python=3.10 -y
124
+ conda activate pointllm
125
+ pip install --upgrade pip # enable PEP 660 support
126
+ pip install -e .
127
+
128
+ # * for training
129
+ pip install ninja
130
+ pip install flash-attn
131
+ ```
132
+
133
+ ### Data Preparation
134
+ #### Objaverse Training Data
135
+ 1. Download the two compressed files of 660K Objaverse colored point clouds [here](https://huggingface.co/datasets/RunsenXu/PointLLM/tree/main). They require about 77GB of storage space.
136
+ 2. Run the following command to merge the two files into one and uncompress it. This will produce a folder named `8192_npy` containing 660K point cloud files named `{Objaverse_ID}_8192.npy`. Each file is a numpy array with dimensions (8192, 6), where the first three dimensions are `xyz` and the last three dimensions are `rgb` in the [0, 1] range (see the loading sketch after this list).
137
+ ```bash
138
+ cat Objaverse_660K_8192_npy_split_a* > Objaverse_660K_8192_npy.tar.gz
139
+ tar -xvf Objaverse_660K_8192_npy.tar.gz
140
+ ```
141
+ 3. In `PointLLM` folder, create a folder `data` and create a soft link to the uncompressed file in the directory.
142
+ ```bash
143
+ cd PointLLM
144
+ mkdir data
145
+ ln -s /path/to/8192_npy data/objaverse_data
146
+ ```
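As a quick sanity check on the downloaded point clouds, each `.npy` file should load as an `(8192, 6)` array of `xyz` + `rgb`. The object ID below is a placeholder, not a real file from the dataset:

```python
import numpy as np

# Placeholder object ID; substitute any real {Objaverse_ID}_8192.npy from the folder.
pc = np.load("data/objaverse_data/OBJAVERSE_ID_8192.npy")
assert pc.shape == (8192, 6)

xyz, rgb = pc[:, :3], pc[:, 3:]                # positions and colors
assert 0.0 <= rgb.min() and rgb.max() <= 1.0   # colors are stored in [0, 1]
```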
147
+
148
+ #### Instruction-Following Data
149
+ 1. In `PointLLM/data` folder, create a directory named `anno_data`.
150
+ 2. Our instruction-following data, including both the simple-description and complex instructions, can be downloaded [here](https://huggingface.co/datasets/RunsenXu/PointLLM). If you have difficulty downloading the data (e.g. network issue), please email the authors.
151
+ - The simple-description data has 660K samples and the complex instructions have 70K samples.
152
+ - Both training data are based on the Objaverse dataset.
153
+ - The complex instructions are generated with GPT-4.
154
+ 3. Put the data files in the `anno_data` directory. The directory should look like this:
155
+ ```bash
156
+ PointLLM/data/anno_data
157
+ ├── PointLLM_brief_description_660K_filtered.json
158
+ ├── PointLLM_brief_description_660K.json
159
+ └── PointLLM_complex_instruction_70K.json
160
+ ```
161
+ 4. Note, the `PointLLM_brief_description_660K_filtered.json` is filtered from `PointLLM_brief_description_660K.json` by removing the 3000 objects we reserved as the validation set. If you want to reproduce the results in our paper, you should use the `PointLLM_brief_description_660K_filtered.json` for training. The `PointLLM_complex_instruction_70K.json` contains objects from the training set.
162
+ 5. If you want to generate the complex instructions by yourself, please refer to our paper for other details. The system prompt is at `pointllm/data/data_generation/system_prompt_gpt4_0613.txt`.
163
+
164
+ #### Evaluation Data
165
+ 1. Download the referencing GT `PointLLM_brief_description_val_200_GT.json` we use for the benchmarks on Objaverse dataset [here](https://huggingface.co/datasets/RunsenXu/PointLLM/blob/main/PointLLM_brief_description_val_200_GT.json), and put it in `PointLLM/data/anno_data`. We also provide the 3000 object ids we filter during training [here](https://huggingface.co/datasets/RunsenXu/PointLLM/blob/main/val_object_ids_3000.txt) and their corresponding referencing GT [here](https://huggingface.co/datasets/RunsenXu/PointLLM/blob/main/PointLLM_brief_description_val_3000_GT.json), which can be used to evaluate on all the 3000 objects.
166
+ 2. Create a directory named `modelnet40_data` in `PointLLM/data`. Download the test split of ModelNet40 point clouds `modelnet40_test_8192pts_fps.dat` [here](https://huggingface.co/datasets/RunsenXu/PointLLM/blob/main/modelnet40_test_8192pts_fps.dat) and put it in `PointLLM/data/modelnet40_data`.
167
+
168
+ ### Training
169
+ #### Download the Initial LLM and Point Encoder Weights
170
+ 1. In `PointLLM` folder, create a directory named `checkpoints`.
171
+ 2. Download the pre-trained LLM and point encoder: [
172
+ PointLLM_7B_v1.1_init](https://huggingface.co/RunsenXu/PointLLM_7B_v1.1_init/tree/main) or [PointLLM_13B_v1.1_init](https://huggingface.co/RunsenXu/PointLLM_13B_v1.1_init/tree/main). Put them in the `checkpoints` directory.
173
+ 3. Note that the above "v1.1" means we use the Vicuna-v1.1 checkpoints, and you do **not** need to download the original LLaMA weights again.
174
+
175
+ #### Start Training
176
+ 1. For stage-1 training, simply run:
177
+ ```bash
178
+ cd PointLLM
179
+ scripts/PointLLM_train_stage1.sh
180
+ ```
181
+ 2. After stage-1 training, start stage-2 training:
182
+ ```bash
183
+ scripts/PointLLM_train_stage2.sh
184
+ ```
185
+
186
+ #### PointLLM-v1.1 and PointLLM-v1.2
187
+ Usually, you do not have to care about the following contents. They are only for reproducing the results in our v1 paper (PointLLM-v1.1). If you want to compare with our models or use our models for downstream tasks, please use PointLLM-v1.2 (refer to our v2 paper), which has better performance.
188
+ <details>
189
+ <summary>The following steps are for reproducing PointLLM-v1.1 (click to expand)</summary>
190
+
191
+ 1. PointLLM v1.1 and v1.2 use slightly different pre-trained point encoders and projectors. If you want to reproduce PointLLM v1.1, edit the `config.json` file in the directory of initial LLM and point encoder weights, for example, `vim checkpoints/PointLLM_7B_v1.1_init/config.json`.
192
+
193
+ 2. Change the key `"point_backbone_config_name"` to specify another point encoder config:
194
+ ```bash
195
+ # change from
196
+ "point_backbone_config_name": "PointTransformer_8192point_2layer" # v1.2
197
+ # to
198
+ "point_backbone_config_name": "PointTransformer_base_8192point", # v1.1
199
+ ```
200
+
201
+ 3. Edit the checkpoint path of the point encoder in `scripts/train_stage1.sh`:
202
+ ```bash
203
+ # change from
204
+ point_backbone_ckpt=$model_name_or_path/point_bert_v1.2.pt # v1.2
205
+ # to
206
+ point_backbone_ckpt=$model_name_or_path/point_bert_v1.1.pt # v1.1
207
+ ```
208
+ </details>
209
+
210
+ ### Chatting
211
+ 1. The trained model checkpoints are available [here](https://huggingface.co/RunsenXu) (including different versions of PointLLM).
212
+ 2. Run the following command to launch a chatbot using the `torch.float32` data type for chatting about 3D models of Objaverse. The model checkpoints will be downloaded automatically. You can also manually download the model checkpoints and specify their paths. Here is an example:
213
+ ```bash
214
+ cd PointLLM
215
+ PYTHONPATH=$PWD python pointllm/eval/PointLLM_chat.py --model_name RunsenXu/PointLLM_7B_v1.2 --data_name data/objaverse_data --torch_dtype float32
216
+ ```
217
+ 3. You can also easily modify the codes for using point clouds other than those from Objaverse, as long as the point clouds input to the model have dimensions (N, 6), where the first three dimensions are `xyz` and the last three dimensions are `rgb` (in [0, 1] range). You may sample the point clouds to have 8192 points, as our model is trained on such point clouds.
218
+ 4. The following table shows GPU requirements for different models and data types. We recommend using `torch.bfloat16` if applicable, which is used in the experiments in our paper.
219
+
220
+ | Model | Data Type | GPU Memory |
221
+ |:--------:|:---------:|:----------:|
222
+ | PointLLM-7B | torch.float16 | 14GB |
223
+ | PointLLM-7B | torch.float32 | 28GB |
224
+ | PointLLM-13B | torch.float16 | 26GB |
225
+ | PointLLM-13B | torch.float32 | 52GB |
226
+
227
+ ### Gradio Demo
228
+ 1. We provide the codes for our online Gradio demo. You can run the following commands to launch the demo locally for chatting and visualization.
229
+ ```bash
230
+ cd PointLLM
231
+ PYTHONPATH=$PWD python pointllm/eval/chat_gradio.py --model_name RunsenXu/PointLLM_7B_v1.2 --data_name data/objaverse_data
232
+ ```
233
+ 2. Kind reminder: if you want to release the demo publicly, please refer to https://www.gradio.app/guides/sharing-your-app#security-and-file-access.
234
+
235
+ ### Evaluation
236
+ #### Inferencing
237
+ 1. Run the following commands to infer the results.
238
+ 2. Different commands for inferencing on different benchmarks (PointLLM_7B_v1.2 as an example):
239
+ ```bash
240
+ cd PointLLM
241
+ export PYTHONPATH=$PWD
242
+
243
+ # Open Vocabulary Classification on Objaverse
244
+ python pointllm/eval/eval_objaverse.py --model_name RunsenXu/PointLLM_7B_v1.2 --task_type classification --prompt_index 0 # or --prompt_index 1
245
+
246
+ # Object captioning on Objaverse
247
+ python pointllm/eval/eval_objaverse.py --model_name RunsenXu/PointLLM_7B_v1.2 --task_type captioning --prompt_index 2
248
+
249
+ # Close-set Zero-shot Classification on ModelNet40
250
+ python pointllm/eval/eval_modelnet_cls.py --model_name RunsenXu/PointLLM_7B_v1.2 --prompt_index 0 # or --prompt_index 1
251
+ ```
252
+ 3. Please check the default command-line arguments of these two scripts. You can specify different prompts, data paths, and other parameters.
253
+ 4. After inferencing, the results will be saved in `{model_name}/evaluation` as a dict with the following format:
254
+ ```bash
255
+ {
256
+ "prompt": "",
257
+ "results": [
258
+ {
259
+ "object_id": "",
260
+ "ground_truth": "",
261
+ "model_output": "",
262
+ "label_name": "" # only for classification on modelnet40
263
+ }
264
+ ]
265
+ }
266
+ ```
267
+
268
+ #### ChatGPT/GPT-4 Evaluation
269
+ 1. Get your OpenAI API key at [https://platform.openai.com/api-keys](https://platform.openai.com/api-keys).
270
+ 2. Run the following commands to evaluate the model outputs in parallel with ChatGPT/GPT-4 (which cost approximately $1.5 to $2.2 USD).
271
+ ```bash
272
+ cd PointLLM
273
+ export PYTHONPATH=$PWD
274
+ export OPENAI_API_KEY=sk-****
275
+
276
+ # Open Vocabulary Classification on Objaverse
277
+ python pointllm/eval/evaluator.py --results_path /path/to/model_output --model_type gpt-4-0613 --eval_type open-free-form-classification --parallel --num_workers 15
278
+
279
+ # Object captioning on Objaverse
280
+ python pointllm/eval/evaluator.py --results_path /path/to/model_output --model_type gpt-4-0613 --eval_type object-captioning --parallel --num_workers 15
281
+
282
+ # Close-set Zero-shot Classification on ModelNet40
283
+ python pointllm/eval/evaluator.py --results_path /path/to/model_output --model_type gpt-3.5-turbo-0613 --eval_type modelnet-close-set-classification --parallel --num_workers 15
284
+ ```
285
+ 3. The evaluation script supports interruption and resumption. You can interrupt the evaluation process at any time by using `Ctrl+C`. This will save the temporary results. If an error occurs during the evaluation, the script will also save the current state. You can resume the evaluation from where it left off by running the same command again.
286
+ 4. The evaluation results will be saved in `{model_name}/evaluation` as another dict.
287
+ Some of the metrics are explained as follows:
288
+ ```bash
289
+ "average_score": The GPT-evaluated captioning score we report in our paper.
290
+ "accuracy": The classification accuracy we report in our paper, including random choices made by ChatGPT when model outputs are vague or ambiguous and ChatGPT outputs "INVALID".
291
+ "clean_accuracy": The classification accuracy after removing those "INVALID" outputs.
292
+ "total_predictions": The number of predictions.
293
+ "correct_predictions": The number of correct predictions.
294
+ "invalid_responses": The number of "INVALID" outputs by ChatGPT.
295
+
296
+ # Some other statistics for calling OpenAI API
297
+ "prompt_tokens": The total number of tokens of the prompts for ChatGPT/GPT-4.
298
+ "completion_tokens": The total number of tokens of the completion results from ChatGPT/GPT-4.
299
+ "GPT_cost": The API cost of the whole evaluation process, in US Dollars 💵.
300
+ ```
301
+ 5. <b>Open-Step Evaluation.</b> You can also start evaluation immediately after inferencing by passing the `--start_eval` flag and specifying the `--gpt_type`. For example:
302
+ ```bash
303
+ python pointllm/eval/eval_objaverse.py --model_name RunsenXu/PointLLM_7B_v1.2 --task_type classification --prompt_index 0 --start_eval --gpt_type gpt-4-0613
304
+ ```
305
+
306
+ #### Traditional Metric Evaluation
307
+ 1. For the object captioning task, run the following command to evaluate model outputs with traditional metrics including BLEU, ROUGE, METEOR, Sentence-BERT, and SimCSE.
308
+ ```bash
309
+ python pointllm/eval/traditional_evaluator.py --results_path /path/to/model_captioning_output
310
+ ```
311
+ 2. Note that we recommend not using BLEU, ROUGE, and METEOR for evaluation as they favor short captions and fall short of capturing semantic accuracy and diversity.
312
+
313
+ ## 📝 TODO List
314
+ - [x] Add inferencing codes with checkpoints.
315
+ - [x] Release instruction-following data.
316
+ - [x] Add training codes.
317
+ - [x] Add evaluation codes.
318
+ - [x] Add gradio demo codes.
319
+
320
+ Community contributions are welcome!👇 If you need any support, please feel free to open an issue or contact us.
321
+ - [ ] Support Phi-2 LLM to make PointLLM more accessible to the community.
322
+ - [ ] Support Chinese LLMs like InternLM.
323
+
324
+ ## 🔗 Citation
325
+
326
+ If you find our work and this codebase helpful, please consider starring this repo 🌟 and cite:
327
+
328
+ ```bibtex
329
+ @article{xu2023pointllm,
330
+ title={PointLLM: Empowering Large Language Models to Understand Point Clouds},
331
+ author={Xu, Runsen and Wang, Xiaolong and Wang, Tai and Chen, Yilun and Pang, Jiangmiao and Lin, Dahua},
332
+ journal={arXiv preprint arXiv:2308.16911},
333
+ year={2023}
334
+ }
335
+ ```
336
+
337
+ ## 📄 License
338
+ <a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/"><img alt="Creative Commons License" style="border-width:0" src="https://i.creativecommons.org/l/by-nc-sa/4.0/80x15.png" /></a>
339
+ <br />
340
+ This work is under the <a rel="license" href="http://creativecommons.org/licenses/by-nc-sa/4.0/">Creative Commons Attribution-NonCommercial-ShareAlike 4.0 International License</a>.
341
+
342
+ ## 📚 Related Work
343
+ Together, Let's make LLM for 3D great!
344
+ - [Point-Bind & Point-LLM](https://arxiv.org/abs/2309.00615): aligns point clouds with Image-Bind, and leverages ImageBind-LLM to reason multi-modality input without 3D-instruction data training.
345
+ - [3D-LLM](https://arxiv.org/abs/2307.12981): employs 2D foundation models to encode multi-view images of 3D point clouds.
346
+
347
+
348
+ ## 👏 Acknowledgements
349
+ - [LLaVA](https://github.com/haotian-liu/LLaVA): Our codebase is built upon LLaVA.
350
+ - [Vicuna](https://github.com/lm-sys/FastChat): We use the Vicuna-7B and Vicuna-13B checkpoints.
351
+ - [Objaverse](https://objaverse.allenai.org): We use models of the Objaverse dataset for training and evaluation.
352
+ - [Cap3D](https://github.com/crockwell/Cap3D/): We use the Cap3D captioning data for our data generation.
353
+ - [ULIP-2](https://github.com/salesforce/ULIP): We use ULIP-2 for pre-training our point cloud encoder.
ThirdParty/PointLLM/__init__.py ADDED
File without changes
ThirdParty/PointLLM/pointllm/__init__.py ADDED
@@ -0,0 +1 @@
+ # from .model import PointLLMLlamaForCausalLM
ThirdParty/PointLLM/pointllm/conversation.py ADDED
@@ -0,0 +1,375 @@
1
+ import dataclasses
2
+ from enum import auto, Enum
3
+ from typing import List, Tuple
4
+
5
+
6
+ class SeparatorStyle(Enum):
7
+ """Different separator style."""
8
+ SINGLE = auto()
9
+ TWO = auto()
10
+ MPT = auto()
11
+
12
+
13
+ @dataclasses.dataclass
14
+ class Conversation:
15
+ """A class that keeps all conversation history."""
16
+ system: str
17
+ roles: List[str]
18
+ messages: List[List[str]]
19
+ offset: int
20
+ sep_style: SeparatorStyle = SeparatorStyle.SINGLE
21
+ sep: str = "###"
22
+ sep2: str = None
23
+ version: str = "Unknown"
24
+
25
+ skip_next: bool = False
26
+
27
+ def reset(self):
28
+ self.messages = self.messages[:self.offset]
29
+
30
+ def get_prompt(self):
31
+ if self.sep_style == SeparatorStyle.SINGLE:
32
+ ret = self.system + self.sep
33
+ for role, message in self.messages:
34
+ if message:
35
+ if type(message) is tuple:
36
+ message, _, _ = message
37
+ ret += role + ": " + message + self.sep
38
+ else:
39
+ ret += role + ":"
40
+ return ret
41
+ elif self.sep_style == SeparatorStyle.TWO:
42
+ seps = [self.sep, self.sep2]
43
+ ret = self.system + seps[0]
44
+ for i, (role, message) in enumerate(self.messages):
45
+ if message:
46
+ if type(message) is tuple:
47
+ message, _, _ = message
48
+ ret += role + ": " + message + seps[i % 2]
49
+ else:
50
+ ret += role + ":"
51
+ return ret
52
+ if self.sep_style == SeparatorStyle.MPT:
53
+ ret = self.system + self.sep
54
+ for role, message in self.messages:
55
+ if message:
56
+ if type(message) is tuple:
57
+ message, _, _ = message
58
+ ret += role + message + self.sep
59
+ else:
60
+ ret += role
61
+ return ret
62
+ else:
63
+ raise ValueError(f"Invalid style: {self.sep_style}")
64
+
65
+ def append_message(self, role, message):
66
+ self.messages.append([role, message])
67
+
68
+ def pop_last_none_message(self):
69
+ # * pop the last message if it's None, this is used for multi-round dialogue
70
+ if self.messages[-1][1] is None:
71
+ self.messages.pop()
72
+
73
+ def get_images(self, return_pil=False):
74
+ images = []
75
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
76
+ if i % 2 == 0:
77
+ if type(msg) is tuple:
78
+ import base64
79
+ from io import BytesIO
80
+ from PIL import Image
81
+ msg, image, image_process_mode = msg
82
+ if image_process_mode == "Pad":
83
+ def expand2square(pil_img, background_color=(122, 116, 104)):
84
+ width, height = pil_img.size
85
+ if width == height:
86
+ return pil_img
87
+ elif width > height:
88
+ result = Image.new(pil_img.mode, (width, width), background_color)
89
+ result.paste(pil_img, (0, (width - height) // 2))
90
+ return result
91
+ else:
92
+ result = Image.new(pil_img.mode, (height, height), background_color)
93
+ result.paste(pil_img, ((height - width) // 2, 0))
94
+ return result
95
+ image = expand2square(image)
96
+ elif image_process_mode == "Crop":
97
+ pass
98
+ elif image_process_mode == "Resize":
99
+ image = image.resize((224, 224))
100
+ else:
101
+ raise ValueError(f"Invalid image_process_mode: {image_process_mode}")
102
+ max_hw, min_hw = max(image.size), min(image.size)
103
+ aspect_ratio = max_hw / min_hw
104
+ max_len, min_len = 800, 400
105
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
106
+ longest_edge = int(shortest_edge * aspect_ratio)
107
+ W, H = image.size
108
+ if H > W:
109
+ H, W = longest_edge, shortest_edge
110
+ else:
111
+ H, W = shortest_edge, longest_edge
112
+ image = image.resize((W, H))
113
+ if return_pil:
114
+ images.append(image)
115
+ else:
116
+ buffered = BytesIO()
117
+ image.save(buffered, format="JPEG")
118
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
119
+ images.append(img_b64_str)
120
+ return images
121
+
122
+ def to_gradio_chatbot(self):
123
+ ret = []
124
+ for i, (role, msg) in enumerate(self.messages[self.offset:]):
125
+ if i % 2 == 0:
126
+ if type(msg) is tuple:
127
+ import base64
128
+ from io import BytesIO
129
+ msg, image, image_process_mode = msg
130
+ max_hw, min_hw = max(image.size), min(image.size)
131
+ aspect_ratio = max_hw / min_hw
132
+ max_len, min_len = 800, 400
133
+ shortest_edge = int(min(max_len / aspect_ratio, min_len, min_hw))
134
+ longest_edge = int(shortest_edge * aspect_ratio)
135
+ W, H = image.size
136
+ if H > W:
137
+ H, W = longest_edge, shortest_edge
138
+ else:
139
+ H, W = shortest_edge, longest_edge
140
+ image = image.resize((W, H))
141
+ # image = image.resize((224, 224))
142
+ buffered = BytesIO()
143
+ image.save(buffered, format="JPEG")
144
+ img_b64_str = base64.b64encode(buffered.getvalue()).decode()
145
+ img_str = f'<img src="data:image/png;base64,{img_b64_str}" alt="user upload image" />'
146
+ msg = msg.replace('<image>', img_str)
147
+ ret.append([msg, None])
148
+ else:
149
+ ret[-1][-1] = msg
150
+ return ret
151
+
152
+ def copy(self):
153
+ return Conversation(
154
+ system=self.system,
155
+ roles=self.roles,
156
+ messages=[[x, y] for x, y in self.messages],
157
+ offset=self.offset,
158
+ sep_style=self.sep_style,
159
+ sep=self.sep,
160
+ sep2=self.sep2)
161
+
162
+ def dict(self):
163
+ if len(self.get_images()) > 0:
164
+ return {
165
+ "system": self.system,
166
+ "roles": self.roles,
167
+ "messages": [[x, y[0] if type(y) is tuple else y] for x, y in self.messages],
168
+ "offset": self.offset,
169
+ "sep": self.sep,
170
+ "sep2": self.sep2,
171
+ }
172
+ return {
173
+ "system": self.system,
174
+ "roles": self.roles,
175
+ "messages": self.messages,
176
+ "offset": self.offset,
177
+ "sep": self.sep,
178
+ "sep2": self.sep2,
179
+ }
180
+
181
+
182
+ conv_v1 = Conversation(
183
+ system="A chat between a curious human and an artificial intelligence assistant. "
184
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
185
+ roles=("Human", "Assistant"),
186
+ messages=(
187
+ ("Human", "Give three tips for staying healthy."),
188
+ ("Assistant",
189
+ "Sure, here are three tips for staying healthy:\n"
190
+ "1. Exercise regularly: Regular physical activity can help improve your overall health and wellbeing. "
191
+ "It can also help reduce your risk of chronic conditions such as obesity, diabetes, heart disease, "
192
+ "and certain cancers. Aim for at least 150 minutes of moderate-intensity aerobic exercise or "
193
+ "75 minutes of vigorous-intensity aerobic exercise per week, along with muscle-strengthening "
194
+ "activities at least two days per week.\n"
195
+ "2. Eat a balanced diet: Eating a balanced diet that is rich in fruits, "
196
+ "vegetables, whole grains, lean proteins, and healthy fats can help support "
197
+ "your overall health. Try to limit your intake of processed and high-sugar foods, "
198
+ "and aim to drink plenty of water throughout the day.\n"
199
+ "3. Get enough sleep: Getting enough quality sleep is essential for your physical "
200
+ "and mental health. Adults should aim for seven to nine hours of sleep per night. "
201
+ "Establish a regular sleep schedule and try to create a relaxing bedtime routine to "
202
+ "help improve the quality of your sleep.")
203
+ ),
204
+ offset=2,
205
+ sep_style=SeparatorStyle.SINGLE,
206
+ sep="###",
207
+ )
208
+
209
+ conv_v1_2 = Conversation(
210
+ system="A chat between a curious human and an artificial intelligence assistant. "
211
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
212
+ roles=("Human", "Assistant"),
213
+ messages=(
214
+ ("Human", "What are the key differences between renewable and non-renewable energy sources?"),
215
+ ("Assistant",
216
+ "Renewable energy sources are those that can be replenished naturally in a relatively "
217
+ "short amount of time, such as solar, wind, hydro, geothermal, and biomass. "
218
+ "Non-renewable energy sources, on the other hand, are finite and will eventually be "
219
+ "depleted, such as coal, oil, and natural gas. Here are some key differences between "
220
+ "renewable and non-renewable energy sources:\n"
221
+ "1. Availability: Renewable energy sources are virtually inexhaustible, while non-renewable "
222
+ "energy sources are finite and will eventually run out.\n"
223
+ "2. Environmental impact: Renewable energy sources have a much lower environmental impact "
224
+ "than non-renewable sources, which can lead to air and water pollution, greenhouse gas emissions, "
225
+ "and other negative effects.\n"
226
+ "3. Cost: Renewable energy sources can be more expensive to initially set up, but they typically "
227
+ "have lower operational costs than non-renewable sources.\n"
228
+ "4. Reliability: Renewable energy sources are often more reliable and can be used in more remote "
229
+ "locations than non-renewable sources.\n"
230
+ "5. Flexibility: Renewable energy sources are often more flexible and can be adapted to different "
231
+ "situations and needs, while non-renewable sources are more rigid and inflexible.\n"
232
+ "6. Sustainability: Renewable energy sources are more sustainable over the long term, while "
233
+ "non-renewable sources are not, and their depletion can lead to economic and social instability.\n")
234
+ ),
235
+ offset=2,
236
+ sep_style=SeparatorStyle.SINGLE,
237
+ sep="###",
238
+ )
239
+
240
+ conv_vicuna_v1_1 = Conversation(
241
+ system="A chat between a curious user and an artificial intelligence assistant. "
242
+ "The assistant gives helpful, detailed, and polite answers to the user's questions.",
243
+ roles=("USER", "ASSISTANT"),
244
+ version="v1",
245
+ messages=(),
246
+ offset=0,
247
+ sep_style=SeparatorStyle.TWO,
248
+ sep=" ",
249
+ sep2="</s>",
250
+ )
251
+
252
+ conv_mpt = Conversation(
253
+ system="""<|im_start|>system
254
+ - You are a helpful language and vision assistant.
255
+ - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
256
+ - You should follow the instructions carefully and explain your answers in detail.""",
257
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
258
+ version="mpt",
259
+ messages=(),
260
+ offset=0,
261
+ sep_style=SeparatorStyle.MPT,
262
+ sep="<|im_end|>",
263
+ )
264
+
265
+ conv_mpt_text = Conversation(
266
+ system="""<|im_start|>system
267
+ - You are a helpful assistant chatbot trained by MosaicML.
268
+ - You answer questions.
269
+ - You are excited to be able to help the user, but will refuse to do anything that could be considered harmful to the user.
270
+ - You are more than just an information source, you are also able to write poetry, short stories, and make jokes.""",
271
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
272
+ version="mpt",
273
+ messages=(),
274
+ offset=0,
275
+ sep_style=SeparatorStyle.MPT,
276
+ sep="<|im_end|>",
277
+ )
278
+
279
+ conv_bair_v1 = Conversation(
280
+ system="BEGINNING OF CONVERSATION:",
281
+ roles=("USER", "GPT"),
282
+ messages=(),
283
+ offset=0,
284
+ sep_style=SeparatorStyle.TWO,
285
+ sep=" ",
286
+ sep2="</s>",
287
+ )
288
+
289
+ simple_conv = Conversation(
290
+ system="A chat between a curious human and an artificial intelligence assistant. "
291
+ "The assistant gives helpful, detailed, and polite answers to the human's questions.",
292
+ roles=("Human", "Assistant"),
293
+ messages=(
294
+ ("Human", "Hi!"),
295
+ ("Assistant", "Hi there! How can I help you today?")
296
+ ),
297
+ offset=2,
298
+ sep_style=SeparatorStyle.SINGLE,
299
+ sep="###",
300
+ )
301
+
302
+ simple_conv_multimodal = Conversation(
303
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
304
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
305
+ "Follow the instructions carefully and explain your answers in detail.",
306
+ roles=("Human", "Assistant"),
307
+ messages=(
308
+ ("Human", "Hi!"),
309
+ ("Assistant", "Hi there! How can I help you today?\n")
310
+ ),
311
+ offset=2,
312
+ sep_style=SeparatorStyle.SINGLE,
313
+ sep="###",
314
+ )
315
+
316
+ simple_conv_mpt_multimodal = Conversation(
317
+ system="""<|im_start|>system
318
+ - You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab.
319
+ - You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language.
320
+ - You should follow the instructions carefully and explain your answers in detail.""",
321
+ roles=("<|im_start|>user\n", "<|im_start|>assistant\n"),
322
+ version="mpt",
323
+ messages=(),
324
+ offset=0,
325
+ sep_style=SeparatorStyle.MPT,
326
+ sep="<|im_end|>",
327
+ )
328
+
329
+ simple_conv_legacy = Conversation(
330
+ system="You are LLaVA, a large language model trained by UW Madison WAIV Lab."
331
+ "You are designed to assist human with a variety of tasks using natural language."
332
+ "Follow the instructions carefully.",
333
+ roles=("Human", "Assistant"),
334
+ messages=(
335
+ ("Human", "Hi!\n\n### Response:"),
336
+ ("Assistant", "Hi there! How can I help you today?\n")
337
+ ),
338
+ offset=2,
339
+ sep_style=SeparatorStyle.SINGLE,
340
+ sep="###",
341
+ )
342
+
343
+ conv_llava_v1 = Conversation(
344
+ system="You are LLaVA, a large language and vision assistant trained by UW Madison WAIV Lab."
345
+ "You are able to understand the visual content that the user provides, and assist the user with a variety of tasks using natural language."
346
+ "Follow the instructions carefully and explain your answers in detail.",
347
+ roles=("USER", "ASSISTANT"),
348
+ version="v1",
349
+ messages=(),
350
+ offset=0,
351
+ sep_style=SeparatorStyle.TWO,
352
+ sep=" ",
353
+ sep2="</s>",
354
+ )
355
+
356
+ default_conversation = conv_v1_2
357
+ conv_templates = {
358
+ "default": conv_v1_2,
359
+ "simple": simple_conv,
360
+ "simple_legacy": simple_conv_legacy,
361
+ "multimodal": simple_conv_multimodal,
362
+ "mpt_multimodal": simple_conv_mpt_multimodal,
363
+ "llava_v1": conv_llava_v1,
364
+
365
+ # fastchat
366
+ "v1": conv_v1_2,
367
+ "bair_v1": conv_bair_v1,
368
+ "vicuna_v1_1": conv_vicuna_v1_1,
369
+ "mpt": conv_mpt,
370
+ "mpt_text": conv_mpt_text,
371
+ }
372
+
373
+
374
+ if __name__ == "__main__":
375
+ print(default_conversation.get_prompt())
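A minimal usage sketch of the conversation templates defined above. The question string is a placeholder, and the import assumes the `pointllm` package is on `PYTHONPATH` as described in the README:

```python
from pointllm.conversation import conv_templates, SeparatorStyle

conv = conv_templates["vicuna_v1_1"].copy()
conv.append_message(conv.roles[0], "What does this point cloud show?")  # USER turn (placeholder)
conv.append_message(conv.roles[1], None)                                # empty ASSISTANT slot to complete
prompt = conv.get_prompt()

# For the TWO separator style used by vicuna_v1_1, generation is typically stopped
# at sep2 ("</s>"); SINGLE-style templates stop at sep ("###") instead.
stop_str = conv.sep if conv.sep_style == SeparatorStyle.SINGLE else conv.sep2
```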
ThirdParty/PointLLM/pointllm/data/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .utils import load_objaverse_point_cloud, pc_norm, farthest_point_sample
+ from .object_point_dataset import ObjectPointCloudDataset, make_object_point_data_module
+ from .modelnet import ModelNet
ThirdParty/PointLLM/pointllm/data/modelnet.py ADDED
@@ -0,0 +1,147 @@
1
+ import os
2
+ import torch
3
+ import numpy as np
4
+ import pickle
5
+ from torch.utils.data import Dataset
6
+ from pointllm.utils import *
7
+ from pointllm.data.utils import *
8
+
9
+ class ModelNet(Dataset):
10
+ def __init__(self, config_path, split, subset_nums=-1, use_color=False):
11
+ """
12
+ Args:
13
+ data_args:
14
+ split: train or test
15
+ """
16
+ super(ModelNet, self).__init__()
17
+
18
+ if config_path is None:
19
+ # * use the default config file in the same dir
20
+ config_path = os.path.join(os.path.dirname(__file__), "modelnet_config", "ModelNet40.yaml")
21
+
22
+ config = cfg_from_yaml_file(config_path)
23
+ # * check data path
24
+ self.root = config["DATA_PATH"]
25
+
26
+ if not os.path.exists(self.root):
27
+ print(f"Data path {self.root} does not exist. Please check your data path.")
28
+ exit()
29
+
30
+ self.npoints = config.npoints
31
+ self.num_category = config.NUM_CATEGORY # * should be 40
32
+ self.random_sample = config.random_sampling
33
+ self.use_height = config.use_height
34
+ self.use_normals = config.USE_NORMALS
35
+ self.subset_nums = subset_nums
36
+ self.normalize_pc = True
37
+ self.use_color = use_color
38
+
39
+ if self.use_height or self.use_normals:
40
+ print(f"Warning: Usually we don't use height or normals for shapenet but use_height: {self.use_height} and \
41
+ use_normals: {self.use_normals}.")
42
+
43
+ self.split = split
44
+ assert (self.split == 'train' or self.split == 'test')
45
+
46
+ self.catfile = os.path.join(os.path.dirname(__file__), "modelnet_config", 'modelnet40_shape_names_modified.txt')
47
+
48
+ # "tv_stand" -> "tv stand"
49
+ self.categories = [line.rstrip() for line in open(self.catfile)] # * list of category names
50
+
51
+ self.save_path = os.path.join(self.root,
52
+ 'modelnet%d_%s_%dpts_fps.dat' % (self.num_category, self.split, self.npoints))
53
+
54
+ print('Load processed data from %s...' % self.save_path)
55
+ with open(self.save_path, 'rb') as f:
56
+ self.list_of_points, self.list_of_labels = pickle.load(f) # * ndarray of N, C: (8192, 6) (xyz and normals)
57
+
58
+ if self.subset_nums > 0:
59
+ # * set random seed
60
+ import random
61
+ random.seed(0)
62
+ # * random choose subset_nums
63
+ idxs = random.sample(range(len(self.list_of_labels)), self.subset_nums)
64
+ self.list_of_labels = [self.list_of_labels[idx] for idx in idxs]
65
+ self.list_of_points = [self.list_of_points[idx] for idx in idxs]
66
+
67
+ # * print len
68
+ print(f"Load {len(self.list_of_points)} data from {self.save_path}.")
69
+
70
+ def __len__(self):
71
+ return len(self.list_of_labels)
72
+
73
+ def _get_item(self, index):
74
+ point_set, label = self.list_of_points[index], self.list_of_labels[index]
75
+
76
+ if self.npoints < point_set.shape[0]:
77
+ if self.random_sample:
78
+ # * random sample
79
+ point_set = point_set[np.random.choice(point_set.shape[0], self.npoints, replace=False)]
80
+ else:
81
+ point_set = farthest_point_sample(point_set, self.npoints)
82
+
83
+ point_set[:, 0:3] = pc_normalize(point_set[:, 0:3])
84
+ if not self.use_normals:
85
+ point_set = point_set[:, 0:3]
86
+
87
+ if self.use_height:
88
+ self.gravity_dim = 1
89
+ height_array = point_set[:, self.gravity_dim:self.gravity_dim + 1] - point_set[:,
90
+ self.gravity_dim:self.gravity_dim + 1].min()
91
+ point_set = np.concatenate((point_set, height_array), axis=1)
92
+
93
+ point_set = np.concatenate((point_set, np.zeros_like(point_set)), axis=-1) if self.use_color else point_set
94
+
95
+ return point_set, label.item() # * ndarray, int
96
+
97
+ def pc_norm(self, pc):
98
+ """ pc: NxC, return NxC """
99
+ xyz = pc[:, :3]
100
+ other_feature = pc[:, 3:]
101
+
102
+ centroid = np.mean(xyz, axis=0)
103
+ xyz = xyz - centroid
104
+ m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
105
+ xyz = xyz / m
106
+
107
+ pc = np.concatenate((xyz, other_feature), axis=1)
108
+ return pc
109
+
110
+ def __getitem__(self, index):
111
+ points, label = self._get_item(index)
112
+ pt_idxs = np.arange(0, points.shape[0]) # 2048
113
+ if self.split == 'train':
114
+ np.random.shuffle(pt_idxs)
115
+ current_points = points[pt_idxs].copy()
116
+
117
+ if self.normalize_pc:
118
+ # * modelnet point cloud is already normalized
119
+ current_points = self.pc_norm(current_points)
120
+
121
+ current_points = torch.from_numpy(current_points).float() # * N, C tensors
122
+ label_name = self.categories[int(label)]
123
+
124
+ data_dict = {
125
+ "indice": index, # * int
126
+ "point_clouds": current_points, # * tensor of N, C
127
+ "labels": label, # * int
128
+ "label_names": label_name # * str
129
+ }
130
+
131
+ return data_dict
132
+
133
+ if __name__ == '__main__':
134
+ import argparse
135
+
136
+ parser = argparse.ArgumentParser(description='ModelNet Dataset')
137
+
138
+ parser.add_argument("--config_path", type=str, default=None, help="config file path.")
139
+ parser.add_argument("--split", type=str, default="test", help="train or test.")
140
+ parser.add_argument("--subset_nums", type=int, default=200)
141
+
142
+ args = parser.parse_args()
143
+
144
+ dataset = ModelNet(config_path=args.config_path, split=args.split, subset_nums=args.subset_nums)
145
+
146
+ # * get the first item
147
+ print(dataset[0])
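A minimal usage sketch for the ModelNet dataset above with a standard DataLoader; the batch size, worker count, and the assumption that the processed .dat file referenced by the config exists are illustrative, not values taken from the repository.

from torch.utils.data import DataLoader

def build_modelnet_loader(split="test", batch_size=32):
    # Each item is a dict with a fixed-size "point_clouds" tensor plus int "labels",
    # so the default collate_fn can stack a batch without a custom collator.
    dataset = ModelNet(config_path=None, split=split, subset_nums=-1, use_color=False)
    return DataLoader(dataset, batch_size=batch_size, shuffle=(split == "train"), num_workers=4)

# batch = next(iter(build_modelnet_loader()))
# print(batch["point_clouds"].shape, batch["labels"].shape)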
ThirdParty/PointLLM/pointllm/data/modelnet_config/ModelNet40.yaml ADDED
@@ -0,0 +1,8 @@
1
+ NAME: ModelNet
2
+ DATA_PATH: data/modelnet40_data
3
+ NUM_CATEGORY: 40
4
+ USE_NORMALS: FALSE
5
+ npoints: 8192
6
+ random_sampling: TRUE
7
+ use_height: FALSE
8
+ use_normals: FALSE
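For reference, the ModelNet class above reads this config both as a dict (config["DATA_PATH"]) and by attribute (config.npoints), so cfg_from_yaml_file is assumed to return an attribute-accessible dict; a rough, hypothetical equivalent is sketched below.

import yaml
from easydict import EasyDict  # assumed dependency; any attribute-accessible dict works

def load_modelnet_cfg(path="modelnet_config/ModelNet40.yaml"):
    # Parse the YAML and wrap it so both cfg["KEY"] and cfg.key access patterns work.
    with open(path) as f:
        return EasyDict(yaml.safe_load(f))

# cfg = load_modelnet_cfg()
# print(cfg["DATA_PATH"], cfg.npoints, cfg.USE_NORMALS)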
ThirdParty/PointLLM/pointllm/data/object_point_dataset.py ADDED
@@ -0,0 +1,250 @@
1
+ import os
2
+ import json
3
+ import torch
4
+ import numpy as np
5
+
6
+ import copy
7
+ import transformers
8
+ from torch.utils.data import Dataset
9
+
10
+ from .utils import *
11
+
12
+
13
+ def make_object_point_data_module(tokenizer: transformers.PreTrainedTokenizer, data_args) -> Dict:
14
+ """Make dataset and collator for Joint3Ddataset with text and point cloud data."""
15
+ """Initialize datasets."""
16
+
17
+ data_collator = DataCollatorForPointTextDataset(tokenizer=tokenizer)
18
+ if data_args.split_train_val:
19
+ print("Loading training datasets.")
20
+ train_dataset = ObjectPointCloudDataset(
21
+ split='train',
22
+ data_path=data_args.data_path,
23
+ anno_path=data_args.anno_path,
24
+ pointnum=data_args.pointnum,
25
+ conversation_types=data_args.conversation_types,
26
+ tokenizer=tokenizer,
27
+ use_color=data_args.use_color,
28
+ data_args=data_args
29
+ )
30
+ print("Done!")
31
+ if data_args.data_debug_num > 0:
32
+ print('Debug mode, using training set as val set.')
33
+ val_dataset = train_dataset
34
+ else:
35
+ # * make a val dataset
36
+ print("Loading validation datasets.")
37
+ val_dataset = ObjectPointCloudDataset(
38
+ split='val', # * load val split
39
+ data_path=data_args.data_path,
40
+ anno_path=data_args.anno_path,
41
+ pointnum=data_args.pointnum,
42
+ conversation_types=data_args.conversation_types,
43
+ tokenizer=tokenizer,
44
+ use_color=data_args.use_color,
45
+ data_args=data_args
46
+ )
47
+ return dict(train_dataset=train_dataset, eval_dataset=val_dataset, data_collator=data_collator)
48
+ else:
49
+ # * use all data as training data
50
+ train_dataset = ObjectPointCloudDataset(
51
+ split='train',
52
+ data_path=data_args.data_path,
53
+ anno_path=data_args.anno_path,
54
+ pointnum=data_args.pointnum,
55
+ conversation_types=data_args.conversation_types,
56
+ use_color=data_args.use_color,
57
+ tokenizer=tokenizer,
58
+ data_args=data_args
59
+ )
60
+ return dict(train_dataset=train_dataset, eval_dataset=None, data_collator=data_collator)
61
+
62
+ class ObjectPointCloudDataset(Dataset):
63
+ """Dataset utilities for objaverse."""
64
+ def __init__(self,
65
+ data_path=None,
66
+ anno_path=None,
67
+ tokenizer=None,
68
+ pointnum=8192,
69
+ split='train',
70
+ conversation_types=None, # * default is simple_des, used for stage1 pre-train
71
+ use_color=True,
72
+ data_args=None):
73
+
74
+ """
75
+ split: only considered when data_args.split_train_val is True.
76
+ conversation_types: tuple used to filter the data; default is ('simple_description',). Other types are:
77
+ "detailed_description", "single_round", "multi_round".
78
+ tokenizer: if None, only point clouds (no tokenized text) are returned by __getitem__.
79
+ """
80
+ super(ObjectPointCloudDataset, self).__init__()
81
+
82
+ """Initialize dataset with object point clouds and text"""
83
+ self.data_path = data_path
84
+ self.anno_path = anno_path
85
+ self.tokenizer = tokenizer
86
+ self.split = split
87
+ if conversation_types is None:
88
+ self.conversation_types = ("simple_description",)
89
+ else:
90
+ self.conversation_types = conversation_types
91
+
92
+ self.data_args = data_args
93
+ self.normalize_pc = True
94
+ self.use_color = use_color
95
+
96
+ self.pointnum = pointnum
97
+ self.point_backbone_config = data_args.point_backbone_config if data_args is not None else None
98
+ self.point_indicator = '<point>'
99
+
100
+ # Load the data list from JSON
101
+ print(f"Loading anno file from {anno_path}.")
102
+ with open(anno_path, "r") as json_file:
103
+ self.list_data_dict = json.load(json_file)
104
+
105
+ # * print the conversations_type
106
+ print(f"Using conversation_type: {self.conversation_types}")
107
+ # * print before filtering
108
+ print(f"Before filtering, the dataset size is: {len(self.list_data_dict)}.")
109
+
110
+ # * iterate the list and filter
111
+ # * these two ids have corrupted colored point files, so filter them when use_color is True
112
+ filter_ids = ['6760e543e1d645d5aaacd3803bcae524', 'b91c0711149d460a8004f9c06d3b7f38'] if self.use_color else []
113
+
114
+ # Iterate the list, filter those "conversation_type" not in self.conversation_types
115
+ self.list_data_dict = [
116
+ data for data in self.list_data_dict
117
+ if data.get('conversation_type', 'simple_description') in self.conversation_types
118
+ and data.get('object_id') not in filter_ids
119
+ ]
120
+
121
+ # * print after filtering
122
+ print(f"After filtering, the dataset size is: {len(self.list_data_dict)}.")
123
+ # * print the size of different conversation_type
124
+ for conversation_type in self.conversation_types:
125
+ print(f"Number of {conversation_type}: {len([data for data in self.list_data_dict if data.get('conversation_type', 'simple_description') == conversation_type])}")
126
+
127
+ if self.data_args is not None and self.data_args.data_debug_num > 0:
128
+ self.list_data_dict = self.list_data_dict[:self.data_args.data_debug_num]
129
+ # * print all the scan_id in debug mode, not using for loop
130
+ print('Debug mode, using: ' + ' '.join([data['object_id'] for data in self.list_data_dict]))
131
+ elif self.data_args is not None and self.data_args.split_train_val:
132
+ # * split train and val with 9:1 ratios
133
+ if self.split == 'train':
134
+ self.list_data_dict = self.list_data_dict[:int(self.data_args.split_ratio * len(self.list_data_dict))]
135
+ print(f"Train set size: {len(self.list_data_dict)}")
136
+ else:
137
+ self.list_data_dict = self.list_data_dict[int(self.data_args.split_ratio * len(self.list_data_dict)):]
138
+ print(f"Val set size: {len(self.list_data_dict)}")
139
+
140
+ def _load_point_cloud(self, object_id, type='objaverse'):
141
+ if type == 'objaverse':
142
+ return self._load_objaverse_point_cloud(object_id)
143
+
144
+ def _load_objaverse_point_cloud(self, object_id):
145
+ filename = f"{object_id}_{self.pointnum}.npy"
146
+ point_cloud = np.load(os.path.join(self.data_path, filename))
147
+
148
+ if not self.use_color:
149
+ point_cloud = point_cloud[:, :3]
150
+
151
+ return point_cloud
152
+
153
+ def pc_norm(self, pc):
154
+ """ pc: NxC, return NxC """
155
+ xyz = pc[:, :3]
156
+ other_feature = pc[:, 3:]
157
+
158
+ centroid = np.mean(xyz, axis=0)
159
+ xyz = xyz - centroid
160
+ m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
161
+ xyz = xyz / m
162
+
163
+ pc = np.concatenate((xyz, other_feature), axis=1)
164
+ return pc
165
+
166
+ def __getitem__(self, index):
167
+ sources = self.list_data_dict[index]
168
+ if isinstance(index, int):
169
+ sources = [sources]
170
+ assert len(sources) == 1, "sources should be a list"
171
+ if self.point_indicator in sources[0]['conversations'][0]['value']:
172
+
173
+ object_id = self.list_data_dict[index]['object_id']
174
+
175
+ # Point cloud representation
176
+ point_cloud = self._load_point_cloud(object_id) # * N, C
177
+ if self.normalize_pc:
178
+ point_cloud = self.pc_norm(point_cloud) # * need to norm since point encoder is norm
179
+
180
+ if self.tokenizer is None:
181
+ data_dict = dict(
182
+ point_clouds=torch.from_numpy(point_cloud.astype(np.float32)),
183
+ object_ids=object_id
184
+ )
185
+ return data_dict
186
+
187
+ sources = preprocess_multimodal_point_cloud(
188
+ copy.deepcopy([e["conversations"] for e in sources]), self.point_backbone_config, point_indicator=self.point_indicator)
189
+ else:
190
+ sources = copy.deepcopy([e["conversations"] for e in sources])
191
+
192
+ data_dict = preprocess_v1(
193
+ sources,
194
+ self.tokenizer)
195
+
196
+ if isinstance(index, int):
197
+ data_dict = dict(input_ids=data_dict["input_ids"][0],
198
+ labels=data_dict["labels"][0])
199
+
200
+ # point exist in the data
201
+ if self.point_indicator in self.list_data_dict[index]['conversations'][0]['value']:
202
+ data_dict['point_clouds'] = torch.from_numpy(point_cloud.astype(np.float32))
203
+
204
+ return data_dict
205
+
206
+ def __len__(self):
207
+ """Return number of utterances."""
208
+ return len(self.list_data_dict)
209
+
210
+ if __name__ == '__main__':
211
+ import argparse
212
+ parser = argparse.ArgumentParser()
213
+
214
+ parser.add_argument("--data_path", default="data/objaverse_data", type=str,
215
+ help="Path to the data directory.")
216
+ parser.add_argument("--anno_path", default=None, type=str, required=True,
217
+ help="Path to the annotation file.")
218
+ parser.add_argument("--split", default='train', type=str,
219
+ help="Whether to use the train or validation dataset.")
220
+ parser.add_argument("--pointnum", default=8192, type=int,
221
+ help="Number of points in the point cloud.")
222
+ parser.add_argument("--data_debug_num", default=0, type=int,
223
+ help="Number of data to debug with.")
224
+ parser.add_argument("--split_train_val", default=False, type=bool,
225
+ help="Whether to split the dataset into training and validation.")
226
+ parser.add_argument("--split_ratio", default=0.9, type=float,
227
+ help="The ratio of training to validation data.")
228
+ parser.add_argument("--tokenizer_path", default=None, type=str, required=True,
229
+ help="Path to the tokenizer config file.")
230
+
231
+ args = parser.parse_args()
232
+
233
+ # Initialize tokenizer
234
+ tokenizer = transformers.AutoTokenizer.from_pretrained(args.tokenizer_path)
235
+
236
+ args.point_backbone_config = None
237
+
238
+ # Initialize dataset
239
+ dataset = ObjectPointCloudDataset(
240
+ data_path=args.data_path,
241
+ anno_path=args.anno_path,
242
+ pointnum=args.pointnum,
243
+ split=args.split,
244
+ tokenizer=tokenizer,
245
+ data_args=args
246
+ )
247
+
248
+ # Example usage
249
+ print(f'Dataset length: {len(dataset)}')
250
+
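A hypothetical sketch of how make_object_point_data_module is consumed by a training script; the data_args fields simply mirror the attributes read above, and the tokenizer/Trainer wiring is an assumption about the caller rather than code from the repository.

import transformers
from types import SimpleNamespace

def build_data_module(tokenizer_path, data_path, anno_path, point_backbone_config):
    # Bundle the attributes ObjectPointCloudDataset expects into a lightweight namespace.
    tokenizer = transformers.AutoTokenizer.from_pretrained(tokenizer_path)
    data_args = SimpleNamespace(
        data_path=data_path,
        anno_path=anno_path,
        pointnum=8192,
        conversation_types=("simple_description",),
        use_color=True,
        data_debug_num=0,
        split_train_val=False,
        split_ratio=0.9,
        point_backbone_config=point_backbone_config,
    )
    return make_object_point_data_module(tokenizer=tokenizer, data_args=data_args)

# module = build_data_module(tokenizer_path, data_path, anno_path, point_backbone_config)
# trainer = transformers.Trainer(model=model, args=training_args, **module)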
ThirdParty/PointLLM/pointllm/data/utils.py ADDED
@@ -0,0 +1,236 @@
1
+ from collections import OrderedDict, defaultdict
2
+
3
+ import transformers
4
+ from pointllm import conversation as conversation_lib
5
+ from dataclasses import dataclass
6
+ from typing import Optional, Dict, Sequence
7
+ import torch
8
+
9
+ import numpy as np
10
+ import os
11
+
12
+ IGNORE_INDEX = -100
13
+
14
+ # * Sample Usage:
15
+ # * from utils import LRUCache
16
+ # * cache = LRUCache(capacity, max_access_count)
17
+ # if self.cache is None:
18
+ # info_data = self.multiview_scannet[info_index]
19
+ # else:
20
+ # info_data = self.cache.get(info_index)
21
+ # if info_data is None or self.cache.get_access_count(info_index) >= self.cache.max_access_count:
22
+ # # If not in cache, or accessed max_access_count times, load it and put it in cache
23
+ # info_data = self.multiview_scannet[info_index]
24
+ # self.cache.put(info_index, info_data)
25
+ # self.cache.reset_access_count(info_index)
26
+
27
+ class LRUCache:
28
+ def __init__(self, capacity, max_access_count):
29
+ self.cache = OrderedDict()
30
+ self.access_count = defaultdict(int)
31
+ self.capacity = capacity
32
+ self.max_access_count = max_access_count
33
+
34
+ def get(self, key):
35
+ if key not in self.cache:
36
+ return None
37
+ value = self.cache.pop(key)
38
+ self.cache[key] = value # Put key as the newest one
39
+ self.access_count[key] += 1
40
+ return value
41
+
42
+ def put(self, key, value):
43
+ if key in self.cache: # Update the value and put it as newest
44
+ self.cache.pop(key)
45
+ elif len(self.cache) == self.capacity: # If cache is full
46
+ oldest_key = next(iter(self.cache))
47
+ self.cache.popitem(last=False) # Remove oldest item
48
+ del self.access_count[oldest_key] # Remove the corresponding access count
49
+ self.cache[key] = value
50
+ self.access_count[key] = 1
51
+
52
+ def get_access_count(self, key):
53
+ return self.access_count.get(key, 0)
54
+
55
+ def reset_access_count(self, key):
56
+ self.access_count[key] = 0
57
+
58
+
59
+ def preprocess_v1(
60
+ sources,
61
+ tokenizer: transformers.PreTrainedTokenizer,
62
+ ) -> Dict:
63
+ conv = conversation_lib.default_conversation.copy()
64
+ roles = {"human": conv.roles[0], "gpt": conv.roles[1]}
65
+
66
+ # Apply prompt templates
67
+ conversations = []
68
+ for i, source in enumerate(sources):
69
+ if roles[source[0]["from"]] != conv.roles[0]:
70
+ # Skip the first one if it is not from human
71
+ source = source[1:]
72
+
73
+ conv.messages = []
74
+ for j, sentence in enumerate(source):
75
+ role = roles[sentence["from"]]
76
+ assert role == conv.roles[j % 2], f"{i}"
77
+ conv.append_message(role, sentence["value"])
78
+ conversations.append(conv.get_prompt())
79
+
80
+ # Tokenize conversations
81
+ input_ids = tokenizer(
82
+ conversations,
83
+ return_tensors="pt",
84
+ padding="longest",
85
+ max_length=tokenizer.model_max_length,
86
+ truncation=True,
87
+ ).input_ids
88
+ targets = input_ids.clone()
89
+
90
+ assert conv.sep_style == conversation_lib.SeparatorStyle.TWO
91
+
92
+ # Mask targets
93
+ sep = conv.sep + conv.roles[1] + ": "
94
+ for conversation, target in zip(conversations, targets):
95
+ total_len = int(target.ne(tokenizer.pad_token_id).sum())
96
+
97
+ rounds = conversation.split(conv.sep2)
98
+ cur_len = 1
99
+ target[:cur_len] = IGNORE_INDEX
100
+ for i, rou in enumerate(rounds):
101
+ if rou == "":
102
+ break
103
+
104
+ parts = rou.split(sep)
105
+ if len(parts) != 2: # * can handle padded tokens
106
+ break
107
+ parts[0] += sep
108
+ round_len = len(tokenizer(rou).input_ids)
109
+ instruction_len = len(tokenizer(parts[0]).input_ids) - 2
110
+
111
+ target[cur_len : cur_len + instruction_len] = IGNORE_INDEX
112
+
113
+ cur_len += round_len
114
+ target[cur_len:] = IGNORE_INDEX # * this is necessary for padded tokens
115
+
116
+ if cur_len < tokenizer.model_max_length:
117
+ if cur_len != total_len: # * unk tokens in the dialogue will cause this.
118
+ target[:] = IGNORE_INDEX
119
+ print(
120
+ f"WARNING: tokenization mismatch: {cur_len} vs. {total_len}."
121
+ f" (ignored)"
122
+ )
123
+
124
+ return dict(
125
+ input_ids=input_ids,
126
+ labels=targets,
127
+ )
128
+
129
+ def preprocess_multimodal_point_cloud(
130
+ sources: Sequence[str],
131
+ point_backbone_config: dict,
132
+ point_indicator: str = "<point>",
133
+ ) -> Dict:
134
+ point_token_len = point_backbone_config['point_token_len']
135
+ default_point_patch_token = point_backbone_config['default_point_patch_token']
136
+
137
+ for source in sources:
138
+ for sentence in source:
139
+ replace_token = default_point_patch_token * point_token_len
140
+ if point_backbone_config['mm_use_point_start_end']:
141
+ replace_token = point_backbone_config['default_point_start_token']+ replace_token + point_backbone_config['default_point_end_token']
142
+ sentence["value"] = sentence["value"].replace(point_indicator, replace_token)
143
+
144
+ return sources
145
+
146
+ def pc_norm(pc):
147
+ """ pc: NxC, return NxC """
148
+ xyz = pc[:, :3]
149
+ other_feature = pc[:, 3:]
150
+
151
+ centroid = np.mean(xyz, axis=0)
152
+ xyz = xyz - centroid
153
+ m = np.max(np.sqrt(np.sum(xyz ** 2, axis=1)))
154
+ xyz = xyz / m
155
+
156
+ pc = np.concatenate((xyz, other_feature), axis=1)
157
+ return pc
158
+
159
+ def load_objaverse_point_cloud(data_path, object_id, pointnum=8192, use_color=False):
160
+ filename = f"{object_id}_{pointnum}.npy"
161
+ point_cloud = np.load(os.path.join(data_path, filename))
162
+
163
+ # * normalize
164
+ point_cloud = pc_norm(point_cloud)
165
+
166
+ if not use_color:
167
+ point_cloud = point_cloud[:, :3]
168
+
169
+ return point_cloud
170
+
171
+ @dataclass
172
+ class DataCollatorForPointTextDataset(object):
173
+ """Collate examples for mixed dataset with text and point cloud data."""
174
+
175
+ tokenizer: transformers.PreTrainedTokenizer
176
+
177
+ def __call__(self, instances: Sequence[Dict]) -> Dict[str, torch.Tensor]:
178
+ input_ids, labels = tuple([instance[key] for instance in instances]
179
+ for key in ("input_ids", "labels"))
180
+ input_ids = torch.nn.utils.rnn.pad_sequence(
181
+ input_ids,
182
+ batch_first=True,
183
+ padding_value=self.tokenizer.pad_token_id)
184
+ labels = torch.nn.utils.rnn.pad_sequence(labels,
185
+ batch_first=True,
186
+ padding_value=IGNORE_INDEX)
187
+ batch = dict(
188
+ input_ids=input_ids,
189
+ labels=labels,
190
+ attention_mask=input_ids.ne(self.tokenizer.pad_token_id),
191
+ )
192
+
193
+ if 'point_clouds' in instances[0]:
194
+ point_clouds = [instance['point_clouds'] for instance in instances]
195
+ if all(x is not None and x.shape == point_clouds[0].shape for x in point_clouds): # * stack only if all point clouds share the same shape
196
+ batch['point_clouds'] = torch.stack(point_clouds)
197
+ else:
198
+ batch['point_clouds'] = point_clouds # * return as lists
199
+
200
+ return batch
201
+
202
+ def farthest_point_sample(point, npoint):
203
+ """
204
+ Input:
205
+ xyz: pointcloud data, [N, D]
206
+ npoint: number of samples
207
+ Return:
208
+ centroids: sampled pointcloud index, [npoint, D]
209
+ """
210
+ N, D = point.shape
211
+ xyz = point[:,:3]
212
+ centroids = np.zeros((npoint,))
213
+ distance = np.ones((N,)) * 1e10
214
+ farthest = np.random.randint(0, N)
215
+ for i in range(npoint):
216
+ centroids[i] = farthest
217
+ centroid = xyz[farthest, :]
218
+ dist = np.sum((xyz - centroid) ** 2, -1)
219
+ mask = dist < distance
220
+ distance[mask] = dist[mask]
221
+ farthest = np.argmax(distance, -1)
222
+ point = point[centroids.astype(np.int32)]
223
+ return point
224
+
225
+ def pc_normalize(pc):
226
+ """
227
+ pc: Nx3 array
228
+ This functions normalizes a point cloud to fit within a unit sphere.
229
+ It first calculates the centroid of the point cloud and then subtracts
230
+ it from all points before scaling all points to fit within a unit sphere.
231
+ """
232
+ centroid = np.mean(pc, axis=0)
233
+ pc = pc - centroid
234
+ m = np.max(np.sqrt(np.sum(pc**2, axis=1)))
235
+ pc = pc / m
236
+ return pc
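A quick self-contained check of the sampling and normalization helpers above on a random cloud; the 1024-point target is an arbitrary illustrative choice.

import numpy as np

pts = np.random.rand(5000, 6)                   # xyz + rgb
sampled = farthest_point_sample(pts, 1024)      # (1024, 6); FPS runs on the xyz columns
sampled[:, :3] = pc_normalize(sampled[:, :3])   # center xyz and scale into the unit sphere
print(sampled.shape, np.abs(sampled[:, :3]).max() <= 1.0)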
ThirdParty/PointLLM/pointllm/eval/PointLLM_chat.py ADDED
@@ -0,0 +1,157 @@
1
+ import argparse
2
+ from transformers import AutoTokenizer
3
+ import torch
4
+ import os
5
+ from pointllm.conversation import conv_templates, SeparatorStyle
6
+ from pointllm.utils import disable_torch_init
7
+ from pointllm.model import *
8
+ from pointllm.model.utils import KeywordsStoppingCriteria
9
+
10
+ from pointllm.data import load_objaverse_point_cloud
11
+
12
+ import os
13
+
14
+ def load_point_cloud(args):
15
+ object_id = args.object_id
16
+ print(f"[INFO] Loading point clouds using object_id: {object_id}")
17
+ point_cloud = load_objaverse_point_cloud(args.data_path, object_id, pointnum=8192, use_color=True)
18
+
19
+ return object_id, torch.from_numpy(point_cloud).unsqueeze_(0).to(torch.float32)
20
+
21
+ def init_model(args):
22
+ # Model
23
+ disable_torch_init()
24
+
25
+ model_path = args.model_name # * matches the --model_name argument defined below
26
+ print(f'[INFO] Model name: {model_path}')
27
+
28
+ tokenizer = AutoTokenizer.from_pretrained(model_path)
29
+ model = PointLLMLlamaForCausalLM.from_pretrained(model_path, low_cpu_mem_usage=False, use_cache=True, torch_dtype=args.torch_dtype).cuda()
30
+ model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer)
31
+
32
+ model.eval()
33
+
34
+ mm_use_point_start_end = getattr(model.config, "mm_use_point_start_end", False)
35
+ # Add special tokens ind to model.point_config
36
+ point_backbone_config = model.get_model().point_backbone_config
37
+
38
+ if mm_use_point_start_end:
39
+ if "v1" in model_path.lower():
40
+ conv_mode = "vicuna_v1_1"
41
+ else:
42
+ raise NotImplementedError
43
+
44
+ conv = conv_templates[conv_mode].copy()
45
+
46
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
47
+ keywords = [stop_str]
48
+
49
+ return model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv
50
+
51
+ def start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv):
52
+ point_token_len = point_backbone_config['point_token_len']
53
+ default_point_patch_token = point_backbone_config['default_point_patch_token']
54
+ default_point_start_token = point_backbone_config['default_point_start_token']
55
+ default_point_end_token = point_backbone_config['default_point_end_token']
56
+ # The while loop will keep running until the user decides to quit
57
+ print("[INFO] Starting conversation... Enter 'q' to exit the program and enter 'exit' to exit the current conversation.")
58
+ while True:
59
+ print("-" * 80)
60
+ # Prompt for object_id
61
+ object_id = input("[INFO] Please enter the object_id or 'q' to quit: ")
62
+
63
+ # Check if the user wants to quit
64
+ if object_id.lower() == 'q':
65
+ print("[INFO] Quitting...")
66
+ break
67
+ else:
68
+ # print info
69
+ print(f"[INFO] Chatting with object_id: {object_id}.")
70
+
71
+ # Update args with new object_id
72
+ args.object_id = object_id.strip()
73
+
74
+ # Load the point cloud data
75
+ try:
76
+ id, point_clouds = load_point_cloud(args)
77
+ except Exception as e:
78
+ print(f"[ERROR] {e}")
79
+ continue
80
+ point_clouds = point_clouds.cuda().to(args.torch_dtype)
81
+
82
+ # Reset the conversation template
83
+ conv.reset()
84
+
85
+ print("-" * 80)
86
+
87
+ # Start a loop for multiple rounds of dialogue
88
+ for i in range(100):
89
+ # This if-else block ensures the initial question from the user is included in the conversation
90
+ qs = input(conv.roles[0] + ': ')
91
+ if qs == 'exit':
92
+ break
93
+
94
+ if i == 0:
95
+ if mm_use_point_start_end:
96
+ qs = default_point_start_token + default_point_patch_token * point_token_len + default_point_end_token + '\n' + qs
97
+ else:
98
+ qs = default_point_patch_token * point_token_len + '\n' + qs
99
+
100
+ # Append the new message to the conversation history
101
+ conv.append_message(conv.roles[0], qs)
102
+ conv.append_message(conv.roles[1], None)
103
+ prompt = conv.get_prompt()
104
+ inputs = tokenizer([prompt])
105
+
106
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
107
+
108
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
109
+ stop_str = keywords[0]
110
+
111
+ with torch.inference_mode():
112
+ output_ids = model.generate(
113
+ input_ids,
114
+ point_clouds=point_clouds,
115
+ do_sample=True,
116
+ temperature=1.0,
117
+ top_k=50,
118
+ max_length=2048,
119
+ top_p=0.95,
120
+ stopping_criteria=[stopping_criteria])
121
+
122
+ input_token_len = input_ids.shape[1]
123
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
124
+ if n_diff_input_output > 0:
125
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
126
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
127
+ outputs = outputs.strip()
128
+ if outputs.endswith(stop_str):
129
+ outputs = outputs[:-len(stop_str)]
130
+ outputs = outputs.strip()
131
+
132
+ # Append the model's response to the conversation history
133
+ conv.pop_last_none_message()
134
+ conv.append_message(conv.roles[1], outputs)
135
+ print(f'{conv.roles[1]}: {outputs}\n')
136
+
137
+ if __name__ == "__main__":
138
+ parser = argparse.ArgumentParser()
139
+ parser.add_argument("--model_name", type=str, \
140
+ default="RunsenXu/PointLLM_7B_v1.2")
141
+
142
+ parser.add_argument("--data_path", type=str, default="data/objaverse_data")
143
+ parser.add_argument("--torch_dtype", type=str, default="float32", choices=["float32", "float16", "bfloat16"])
144
+
145
+ args = parser.parse_args()
146
+
147
+ dtype_mapping = {
148
+ "float32": torch.float32,
149
+ "float16": torch.float16,
150
+ "bfloat16": torch.bfloat16,
151
+ }
152
+
153
+ args.torch_dtype = dtype_mapping[args.torch_dtype]
154
+
155
+ model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv = init_model(args)
156
+
157
+ start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv)
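The round-0 prompt layout used in start_conversation (point placeholder tokens prepended to the user question) can be summarized as a standalone helper; this is only a restatement of the logic above, not additional code from the script.

def build_first_round_question(question, point_backbone_config, mm_use_point_start_end):
    # Prepend the point placeholder tokens so the point features can be spliced in later.
    n = point_backbone_config['point_token_len']
    patch = point_backbone_config['default_point_patch_token']
    if mm_use_point_start_end:
        return (point_backbone_config['default_point_start_token']
                + patch * n
                + point_backbone_config['default_point_end_token']
                + '\n' + question)
    return patch * n + '\n' + question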
ThirdParty/PointLLM/pointllm/eval/chat_gradio.py ADDED
@@ -0,0 +1,394 @@
1
+ import argparse
2
+ from transformers import AutoTokenizer
3
+ import torch
4
+ import os
5
+ from pointllm.conversation import conv_templates, SeparatorStyle
6
+ from pointllm.utils import disable_torch_init
7
+ from pointllm.model import *
8
+ from pointllm.model.utils import KeywordsStoppingCriteria
9
+ import numpy as np
10
+
11
+ from pointllm.data import pc_norm, farthest_point_sample
12
+
13
+ import os
14
+
15
+ # Additional import for gradio
16
+ import gradio as gr
17
+ import open3d as o3d
18
+ import plotly.graph_objects as go
19
+ import objaverse
20
+ import time
21
+
22
+ import logging
23
+
24
+
25
+ def change_input_method(input_method):
26
+ if input_method == 'File':
27
+ result = [gr.update(visible=True),
28
+ gr.update(visible=False)]
29
+ elif input_method == 'Object ID':
30
+ result = [gr.update(visible=False),
31
+ gr.update(visible=True)]
32
+ return result
33
+
34
+ def init_model(args):
35
+ # Model
36
+ disable_torch_init()
37
+ model_name = os.path.expanduser(args.model_name)
38
+
39
+ # * print the model_name (get the basename)
40
+ print(f'[INFO] Model name: {os.path.basename(model_name)}')
41
+ logging.warning(f'Model name: {os.path.basename(model_name)}')
42
+
43
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
44
+ model = PointLLMLlamaForCausalLM.from_pretrained(model_name, low_cpu_mem_usage=False, use_cache=True).cuda()
45
+ model.initialize_tokenizer_point_backbone_config_wo_embedding(tokenizer)
46
+
47
+ model.eval()
48
+
49
+ mm_use_point_start_end = getattr(model.config, "mm_use_point_start_end", False)
50
+ # Add special tokens ind to model.point_config
51
+ point_backbone_config = model.get_model().point_backbone_config
52
+
53
+ conv = conv_templates["vicuna_v1_1"].copy()
54
+
55
+ stop_str = conv.sep if conv.sep_style != SeparatorStyle.TWO else conv.sep2
56
+ keywords = [stop_str]
57
+
58
+ return model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv
59
+
60
+ def start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv):
61
+ point_token_len = point_backbone_config['point_token_len']
62
+ default_point_patch_token = point_backbone_config['default_point_patch_token']
63
+ default_point_start_token = point_backbone_config['default_point_start_token']
64
+ default_point_end_token = point_backbone_config['default_point_end_token']
65
+
66
+ # The while loop will keep running until the user decides to quit
67
+ print("[INFO] Starting conversation...")
68
+ logging.warning("Starting conversation...")
69
+ while True:
70
+ print("-" * 80)
71
+ logging.warning("-" * 80)
72
+
73
+ # Reset the conversation template
74
+ conv.reset()
75
+
76
+ def confirm_point_cloud(input_choice, object_id_input, point_cloud_input, chatbot, answer_time, conv):
77
+ objects = None
78
+ data = None
79
+ object_id_input = object_id_input.strip()
80
+
81
+ print("%" * 80)
82
+ logging.warning("%" * 80)
83
+
84
+ if input_choice == 'File':
85
+ file = point_cloud_input.name
86
+ print(f"Uploading file: {file}.")
87
+ logging.warning(f"Uploading file: {file}.")
88
+ elif input_choice == 'Object ID':
89
+ file = os.path.join(args.data_path, "{}_8192.npy".format(object_id_input))
90
+ print(f"Object_id: {object_id_input}")
91
+ logging.warning(f"Object_id: {object_id_input}")
92
+
93
+ object_uids = [object_id_input]
94
+ objects = objaverse.load_objects(uids=object_uids)
95
+ print("%" * 80)
96
+ logging.warning("%" * 80)
97
+
98
+ manual_no_color = "no_color" in file
99
+
100
+ try:
101
+ if '.ply' in file:
102
+ pcd = o3d.io.read_point_cloud(file)
103
+ points = np.asarray(pcd.points) # xyz
104
+ colors = np.asarray(pcd.colors) # rgb, if available
105
+ # * if no colors actually, empty array
106
+ if colors.size == 0:
107
+ colors = None
108
+ elif '.npy' in file:
109
+ data = np.load(file)
110
+ if data.shape[1] >= 3:
111
+ points = data[:, :3]
112
+ else:
113
+ raise ValueError("Input array has the wrong shape. Expected: [N, 3]. Got: {}.".format(data.shape))
114
+ colors = None if data.shape[1] < 6 else data[:, 3:6]
115
+ else:
116
+ raise ValueError("Not supported data format.")
117
+ # error
118
+ except Exception as e:
119
+ print(f"[ERROR] {e}")
120
+ logging.warning(f"[ERROR] {e}")
121
+
122
+ chatbot_system_message = "Sorry, the Objaverse id is not supported or there is something wrong with the uploaded file!"
123
+ print(f"[ChatBot System Message]: {chatbot_system_message}")
124
+ logging.warning(f"[ChatBot System Message]: {chatbot_system_message}")
125
+
126
+ outputs = f"<span style='color: red;'>[System] {chatbot_system_message}</span>"
127
+ chatbot = chatbot + [[None, outputs]]
128
+
129
+ return None, None, chatbot, answer_time, None
130
+
131
+ if manual_no_color:
132
+ colors = None
133
+
134
+ if colors is not None:
135
+ # * if colors in range(0-1)
136
+ if np.max(colors) <= 1:
137
+ color_data = np.multiply(colors, 255).astype(int) # Convert float values (0-1) to integers (0-255)
138
+ # * if colors in range(0-255)
139
+ elif np.max(colors) <= 255:
140
+ color_data = colors.astype(int)
141
+ else:
142
+ color_data = np.zeros_like(points).astype(int) # Default to black color if RGB information is not available
143
+ colors = color_data.astype(np.float32) / 255 # model input is (0-1)
144
+
145
+ # Convert the RGB color data to a list of RGB strings in the format 'rgb(r, g, b)'
146
+ color_strings = ['rgb({},{},{})'.format(r, g, b) for r, g, b in color_data]
147
+
148
+ fig = go.Figure(
149
+ data=[
150
+ go.Scatter3d(
151
+ x=points[:, 0], y=points[:, 1], z=points[:, 2],
152
+ mode='markers',
153
+ marker=dict(
154
+ size=1.2,
155
+ color=color_strings, # Use the list of RGB strings for the marker colors
156
+ )
157
+ )
158
+ ],
159
+ layout=dict(
160
+ scene=dict(
161
+ xaxis=dict(visible=False),
162
+ yaxis=dict(visible=False),
163
+ zaxis=dict(visible=False)
164
+ ),
165
+ paper_bgcolor='rgb(255,255,255)' # Set the background color to white
166
+ ),
167
+ )
168
+
169
+ points = np.concatenate((points, colors), axis=1)
170
+ if 8192 < points.shape[0]:
171
+ points = farthest_point_sample(points, 8192)
172
+ point_clouds = pc_norm(points)
173
+ point_clouds = torch.from_numpy(point_clouds).unsqueeze_(0).to(torch.float32).cuda()
174
+
175
+ answer_time = 0
176
+ conv.reset()
177
+
178
+ outputs = "<span style='color: red;'>[System] New Point Cloud</span>"
179
+ chatbot = chatbot + [[None, outputs]]
180
+
181
+ return fig, list(objects.values())[0] if objects is not None else None, chatbot, answer_time, point_clouds
182
+
183
+ def answer_generate(history, answer_time, point_clouds, conv):
184
+ if point_clouds is None:
185
+ outputs = "<span style='color: red;'>[System] Please input point cloud! </span>"
186
+ history[-1][1] = outputs
187
+ yield history
188
+ else:
189
+ print(f"Answer Time: {answer_time}")
190
+ logging.warning(f"Answer Time: {answer_time}")
191
+ input_text = history[-1][0]
192
+ qs = input_text
193
+
194
+ if answer_time == 0:
195
+ if mm_use_point_start_end:
196
+ qs = default_point_start_token + default_point_patch_token * point_token_len + default_point_end_token + '\n' + qs
197
+ else:
198
+ qs = default_point_patch_token * point_token_len + '\n' + qs
199
+
200
+ # Append the new message to the conversation history
201
+ conv.append_message(conv.roles[0], qs)
202
+ conv.append_message(conv.roles[1], None)
203
+ prompt = conv.get_prompt()
204
+ print("#" * 80)
205
+ print(f'{prompt.replace("<point_patch>" * point_token_len, f"<point_patch> * {point_token_len}")}') # for concise printing
206
+ print("#" * 80)
207
+
208
+ logging.warning("#" * 80)
209
+ logging.warning(f'{prompt.replace("<point_patch>" * point_token_len, f"<point_patch> * {point_token_len}")}') # for concise printing
210
+ logging.warning("#" * 80)
211
+ inputs = tokenizer([prompt])
212
+
213
+ input_ids = torch.as_tensor(inputs.input_ids).cuda()
214
+
215
+ stopping_criteria = KeywordsStoppingCriteria(keywords, tokenizer, input_ids)
216
+ stop_str = keywords[0]
217
+
218
+ try:
219
+ if input_ids.shape[1] >= 2047:
220
+ raise ValueError("Current context length exceeds the maximum context length (2048) of the model.")
221
+ with torch.inference_mode():
222
+ output_ids = model.generate(
223
+ input_ids,
224
+ point_clouds=point_clouds,
225
+ do_sample=True,
226
+ temperature=1.0,
227
+ top_k=50,
228
+ max_length=2048,
229
+ top_p=0.95,
230
+ stopping_criteria=[stopping_criteria])
231
+
232
+ input_token_len = input_ids.shape[1]
233
+ n_diff_input_output = (input_ids != output_ids[:, :input_token_len]).sum().item()
234
+ if n_diff_input_output > 0:
235
+ print(f'[Warning] {n_diff_input_output} output_ids are not the same as the input_ids')
236
+ logging.warning(f'{n_diff_input_output} output_ids are not the same as the input_ids')
237
+ outputs = tokenizer.batch_decode(output_ids[:, input_token_len:], skip_special_tokens=True)[0]
238
+ outputs = outputs.strip()
239
+ if outputs.endswith(stop_str):
240
+ outputs = outputs[:-len(stop_str)]
241
+ outputs = outputs.strip()
242
+
243
+ # Append the model's response to the conversation history
244
+ conv.pop_last_none_message()
245
+ conv.append_message(conv.roles[1], outputs)
246
+ print(f'{conv.roles[1]}: {outputs}\n')
247
+ logging.warning(f'{conv.roles[1]}: {outputs}\n')
248
+ answer_time += 1
249
+ history[-1][1] = ""
250
+ for character in outputs:
251
+ history[-1][1] += character
252
+ yield history
253
+ # error
254
+ except Exception as e:
255
+ print(f"[ERROR] {e}")
256
+ logging.warning(f"[ERROR] {e}")
257
+
258
+ if input_ids.shape[1] >= 2047:
259
+ chatbot_system_message = "Current context length exceeds the maximum context length (2048) of the model. Please press 'Clear' to restart."
260
+ else:
261
+ chatbot_system_message = "Sorry, something went wrong during generation. Please check your uploaded point cloud or the Objaverse id, and \
262
+ confirm the point cloud again."
263
+ print(f"[ChatBot System Message]: {chatbot_system_message}")
264
+ logging.warning(f"[ChatBot System Message]: {chatbot_system_message}")
265
+
266
+ outputs = f"<span style='color: red;'>[System] {chatbot_system_message}</span>"
267
+ history[-1][1] = outputs
268
+ yield history
269
+
270
+ with gr.Blocks() as demo:
271
+ answer_time = gr.State(value=0)
272
+ point_clouds = gr.State(value=None)
273
+ conv_state = gr.State(value=conv.copy())
274
+ gr.Markdown(
275
+ """
276
+ # PointLLM: Empowering Large Language Models to Understand Point Clouds. 🚀
277
+ If you find this demo interesting, please consider starring 🌟 our github repo. :)
278
+ [[Project Page](https://runsenxu.com/projects/PointLLM)] [[Paper](https://arxiv.org/abs/2308.16911)] [[Code](https://github.com/OpenRobotLab/PointLLM)]
279
+ """
280
+ )
281
+ with gr.Row():
282
+ with gr.Column():
283
+ input_choice = gr.Radio(['File', 'Object ID'], value='Object ID', interactive=True, label='Input Method', info="How do you want to load point clouds?")
284
+ object_id_input = gr.Textbox(visible=True, lines=1, label='Object ID Input')
285
+ point_cloud_input = gr.File(visible=False, label="Upload Point Cloud File (PLY, NPY)")
286
+ output = gr.Plot()
287
+ btn = gr.Button(value="Confirm Point Cloud")
288
+ model3D = gr.Model3D()
289
+ with gr.Column():
290
+ chatbot = gr.Chatbot([], elem_id="chatbot", height=560) # ,color_map=("green", "pink")
291
+
292
+ def user(user_message, history):
293
+ return "", history + [[user_message, None]]
294
+
295
+ def clear_conv(history, conv):
296
+ conv.reset()
297
+ return None, 0
298
+
299
+ with gr.Row():
300
+ text_input = gr.Textbox(
301
+ show_label=False,
302
+ placeholder="Enter text and press enter",
303
+ container=False,
304
+ )
305
+ run_button = gr.Button("Send")
306
+
307
+ clear = gr.Button("Clear")
308
+ text_input.submit(user, [text_input, chatbot], [text_input, chatbot], queue=False).then(answer_generate, [chatbot, answer_time, point_clouds, conv_state], chatbot).then(lambda x : x+1,answer_time, answer_time)
309
+ clear.click(clear_conv, inputs=[chatbot, conv_state], outputs=[chatbot, answer_time], queue=False)
310
+
311
+ btn.click(confirm_point_cloud, inputs=[input_choice, object_id_input, point_cloud_input, chatbot, answer_time, conv_state], outputs=[output, model3D, chatbot, answer_time, point_clouds])
312
+
313
+ input_choice.change(change_input_method, input_choice, [point_cloud_input, object_id_input])
314
+ run_button.click(user, [text_input, chatbot], [text_input, chatbot], queue=False).then(answer_generate, [chatbot, answer_time, point_clouds, conv_state], chatbot).then(lambda x : x+1, answer_time, answer_time)
315
+
316
+ gr.Markdown(
317
+ """
318
+ ### Usage:
319
+ 1. Upload your point cloud file (ply, npy only) or input the supported [Objaverse object id (uid)](https://drive.google.com/file/d/1gLwA7aHfy1KCrGeXlhICG9rT2387tWY8/view?usp=sharing) (currently 660K objects only, you may try the example object ids below).
320
+ 2. If your point cloud file does not contain colors, make sure the file name contains 'no_color' (e.g., 'xxx_no_color.npy'); black will then be assigned.
321
+ 3. If uploading your own point cloud file with color in npy format, the first three dimensions should be xyz, and the next three dimensions should be rgb. The rgb values should range from **0 to 1**.
322
+ 4. Click **Confirm Point Cloud**.
323
+ 5. As we use FPS sampling to downsample the point cloud to 8192 points, it may take a long time to confirm the point cloud if the point cloud has too many points. You may use random sampling to downsample the point cloud before uploading.
324
+ 6. Once '[System] New Point Cloud' appears in the dialogue box, a new conversation with PointLLM is initialized.
325
+ 7. The 'Clear' button will clear the conversation history.
326
+ """)
327
+ with gr.Accordion("Example Objaverse object ids in the validation set!", open=False):
328
+ example_object_ids = [ ["b4bbf2116b1a41a5a3b9d3622b07074c", "0b8da82a3d7a436f9b585436c4b72f56", "650c53d68d374c18886aab91bcf8bb54"],
329
+ ["983fa8b23a084f5dacd157e6c9ceba97", "8fe23dd4bf8542b49c3a574b33e377c3", "83cb2a9e9afb47cd9f45461613796645"],
330
+ ["3d679a3888c548afb8cf889915af7fd2", "7bcf8626eaca40e592ffd0aed08aa30b", "69865c89fc7344be8ed5c1a54dbddc20"],
331
+ ["252f3b3f5cd64698826fc1ab42614677", "e85ebb729b02402bbe3b917e1196f8d3", "97367c4740f64935b7a5e34ae1398035"],
332
+ ["fc8dd5a2fc9f4dd19ad6a64a8a6e89e9", "8257772b0e2f408ba269264855dfea00", "d6a3520486bb474f9b5e72eda8408974"],
333
+ ["3d10918e6a9a4ad395a7280c022ad2b9", "00002bcb84af4a4781174e62619f14e2", "76ba80230d454de996878c2763fe7e5c"]]
334
+ gr.DataFrame(
335
+ type="array",
336
+ headers=["Example Object IDs"] * 3,
337
+ row_count=6,
338
+ col_count=3,
339
+ value=example_object_ids
340
+ )
341
+ gr.Markdown(
342
+ """
343
+ #### Terms of use
344
+ By using this service, users are required to agree to the following terms: The service is a research preview intended for non-commercial use only. It only provides limited safety measures and may generate offensive content. It must not be used for any illegal, harmful, violent, racist, or sexual purposes. The service may collect user dialogue data for future research.
345
+ """
346
+ )
347
+ demo.queue()
348
+ demo.launch(server_name="0.0.0.0", server_port=args.port, share=False) # server_port=7832, share=True
349
+
350
+ if __name__ == "__main__":
351
+ # ! To release this demo in public, make sure to start in a place where no important data is stored.
352
+ # ! Please check 1. the launch dir 2. the tmp dir (GRADIO_TEMP_DIR)
353
+ # ! refer to https://www.gradio.app/guides/sharing-your-app#security-and-file-access
354
+ parser = argparse.ArgumentParser()
355
+ parser.add_argument("--model-name", type=str, \
356
+ default="RunsenXu/PointLLM_7B_v1.2")
357
+
358
+
359
+ parser.add_argument("--data_path", type=str, default="data/objaverse_data", required=False)
360
+ parser.add_argument("--pointnum", type=int, default=8192)
361
+
362
+ parser.add_argument("--log_file", type=str, default="serving_workdirs/serving_log.txt")
363
+ parser.add_argument("--tmp_dir", type=str, default="serving_workdirs/tmp")
364
+
365
+ # For gradio
366
+ parser.add_argument("--port", type=int, default=7810)
367
+
368
+ args = parser.parse_args()
369
+
370
+ # * make serving dirs
371
+ os.makedirs(os.path.dirname(args.log_file), exist_ok=True)
372
+ os.makedirs(args.tmp_dir, exist_ok=True)
373
+
374
+ # * add the current time for log name
375
+ args.log_file = args.log_file.replace(".txt", f"_{time.strftime('%Y-%m-%d-%H-%M-%S', time.localtime())}.txt")
376
+
377
+ logging.basicConfig(
378
+ filename=args.log_file,
379
+ level=logging.WARNING, # * default gradio is info, so use warning
380
+ format='%(asctime)s - %(message)s',
381
+ datefmt='%Y-%m-%d %H:%M:%S'
382
+ )
383
+
384
+ logging.warning("-----New Run-----")
385
+ logging.warning(f"args: {args}")
386
+
387
+ print("-----New Run-----")
388
+ print(f"[INFO] Args: {args}")
389
+
390
+ # * set env variable GRADIO_TEMP_DIR to args.tmp_dir
391
+ os.environ["GRADIO_TEMP_DIR"] = args.tmp_dir
392
+
393
+ model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv = init_model(args)
394
+ start_conversation(args, model, tokenizer, point_backbone_config, keywords, mm_use_point_start_end, conv)
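For readability, the color handling inside confirm_point_cloud (accepting 0-1 floats or 0-255 integers and falling back to black when colors are missing) amounts to the following helper; it is a restatement of the logic above, not code used by the demo.

import numpy as np

def normalize_colors(colors, num_points):
    # Returns (0-255 ints for plotting, 0-1 floats for the model); black if colors are missing.
    if colors is None:
        color_data = np.zeros((num_points, 3), dtype=int)
    elif np.max(colors) <= 1:
        color_data = np.multiply(colors, 255).astype(int)
    else:
        color_data = colors.astype(int)
    return color_data, color_data.astype(np.float32) / 255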