# DECO / hot_analysis / create_part_probability_mesh.py
# (commit 99a05f0 by ac5113, "added files"; 3.14 kB)
import numpy as np
import os
import json
import trimesh
import seaborn as sns
# Load the DCA train/val and test splits and concatenate every array stored in
# the train/val archive with its test counterpart along axis 0, producing one
# dict that covers the whole dataset.
# NOTE: allow_pickle=True lets np.load read object arrays, but it unpickles
# arbitrary data -- only point these paths at trusted files.
data_dir = '/is/cluster/work/stripathi/pycharm_remote/dca_contact/data/dataset_extras'
combined_npz = {}
# An NpzFile keeps its archive open until closed; the context managers release
# both file handles once every array has been read into memory.
with np.load(os.path.join(data_dir, 'hot_dca_trainval.npz'), allow_pickle=True) as trainval_npz, \
        np.load(os.path.join(data_dir, 'hot_dca_test.npz'), allow_pickle=True) as test_npz:
    # keys are taken from the train/val archive; each must also exist in the test archive
    for key in trainval_npz.keys():
        combined_npz[key] = np.concatenate([trainval_npz[key], test_npz[key]], axis=0)
# Load the SMPL per-vertex part segmentation. Judging by how it is used below,
# it maps a part name to a list of vertex indices (TODO confirm against the JSON).
segmentation_path = 'data/smpl_vert_segmentation.json'
with open(segmentation_path, 'rb') as seg_file:
    part_segmentation = json.load(seg_file)

# Fold finer-grained parts into a coarser parent part: each child's vertex list
# is appended to its parent's list and the child entry is removed.
combine_keys = {
    'leftFoot': ['leftToeBase'],
    'rightFoot': ['rightToeBase'],
    'leftHand': ['leftHandIndex1'],
    'rightHand': ['rightHandIndex1'],
    'spine': ['spine1', 'spine2'],
    'head': ['neck'],
}
for parent_part, child_parts in combine_keys.items():
    for child_part in child_parts:
        # pop() both fetches the child's vertices and drops its entry
        part_segmentation[parent_part] += part_segmentation.pop(child_part)

# Invert the segmentation into vertex index -> part name. If a vertex were
# listed under several parts, the part iterated last would win.
part_segmentation_rev = {
    vert: part
    for part, vert_ids in part_segmentation.items()
    for vert in vert_ids
}
# Vertex count of the SMPL template mesh; every contact label and the colored
# mesh below are assumed to have exactly this many entries.
NUM_SMPL_VERTS = 6890
# Exponent (< 1) applied to the normalized counts: on [0, 1], x ** 0.3 lifts
# small values toward 1, so rarely-contacted vertices remain visible.
CONTACT_PROB_GAMMA = 0.3

# count the number of contact instances per vertex by summing every sample's
# contact label (presumably a 0/1 vector of length NUM_SMPL_VERTS -- confirm)
per_vert_contact_count = np.zeros(NUM_SMPL_VERTS)
for cls in combined_npz['contact_label']:
    per_vert_contact_count += cls

# calculate the maximum contact count per part (in this file it is currently
# only referenced by the commented-out per-part variant)
part_contact_max = {
    part: np.max(per_vert_contact_count[vert_ids])
    for part, vert_ids in part_segmentation.items()
}

# calculate the contact probability globally: normalize by the single largest
# count (computed once, not once per vertex), then apply the gamma curve.
max_contact_count = per_vert_contact_count.max()
if max_contact_count > 0:
    contact_prob = (per_vert_contact_count / max_contact_count) ** CONTACT_PROB_GAMMA
else:
    # no contacts at all: avoid 0/0 = NaN and fall back to an all-zero map
    contact_prob = np.zeros_like(per_vert_contact_count)
# save the contact probability mesh: paint the global per-vertex contact
# probability onto the SMPL template mesh and write it out as an OBJ file.
outdir = "/is/cluster/work/stripathi/pycharm_remote/dca_contact/hot_analysis"
out_path = os.path.join(outdir, "contact_probability_mesh.obj")

# load the template smpl mesh and turn each vertex's probability into a color
# using the 'jet' colormap
mesh = trimesh.load_mesh('data/smpl/smpl_neutral_tpose.ply')
vertex_colors = trimesh.visual.interpolate(contact_prob, 'jet')
mesh.visual.vertex_colors = vertex_colors

mesh.export(out_path)
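# # calculate the contact probability per part (disabled)
# # NOTE(review): if the assignment below is the body of the `if 'Hand' in ...`, its
# # ternary test `'Hand' not in ...` is always False, so only hand vertices get a value
# # (exponent 0.8) and the 0.4 branch is dead.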
# contact_prob = np.zeros(6890)
# for vid in range(6890):
# if 'Hand' in part_segmentation_rev[vid]:
# contact_prob[vid] = (per_vert_contact_count[vid] / part_contact_max[part_segmentation_rev[vid]]) ** 0.4 if 'Hand' not in part_segmentation_rev[vid] else (per_vert_contact_count[vid] / part_contact_max[part_segmentation_rev[vid]]) ** 0.8
#
# # save the contact probability mesh
# outdir = "/is/cluster/work/stripathi/pycharm_remote/dca_contact/hot_analysis"
#
# # load template smpl mesh
# mesh = trimesh.load_mesh('data/smpl/smpl_neutral_tpose.ply')
# vertex_colors = trimesh.visual.interpolate(contact_prob, 'jet')
# # set the vertex colors of the mesh
# mesh.visual.vertex_colors = vertex_colors
# # save the mesh
# out_path = os.path.join(outdir, "contact_probability_mesh_part.obj")
# mesh.export(out_path)