# NOTE: web-page residue ("Spaces: Sleeping Sleeping") preceded the script here; it is not part of the code.
import numpy as np | |
import os | |
import json | |
import trimesh | |
import seaborn as sns | |
# Load the combined DCA train/val and test archives and merge them key by key.
# NOTE(review): allow_pickle=True lets np.load unpickle arbitrary objects; only
# run this on archives that come from a trusted source.
data_dir = '/is/cluster/work/stripathi/pycharm_remote/dca_contact/data/dataset_extras'
trainval_path = os.path.join(data_dir, 'hot_dca_trainval.npz')
test_path = os.path.join(data_dir, 'hot_dca_test.npz')
# The context managers close the underlying .npz file handles once every array
# has been read. np.concatenate returns new in-memory arrays, so combined_npz
# remains usable after the archives are closed.
with np.load(trainval_path, allow_pickle=True) as trainval_npz, \
        np.load(test_path, allow_pickle=True) as test_npz:
    # Each key of the train/val archive must also exist in the test archive;
    # the two arrays are stacked along the first axis.
    combined_npz = {
        key: np.concatenate([trainval_npz[key], test_npz[key]], axis=0)
        for key in trainval_npz.keys()
    }
# Read the part -> vertex-index-list mapping from the segmentation JSON file.
segmentation_path = 'data/smpl_vert_segmentation.json'
with open(segmentation_path, 'rb') as seg_file:
    part_segmentation = json.load(seg_file)

# Parent part -> sub-parts whose vertex lists are folded into the parent.
combine_keys = {
    'leftFoot': ['leftToeBase'],
    'rightFoot': ['rightToeBase'],
    'leftHand': ['leftHandIndex1'],
    'rightHand': ['rightHandIndex1'],
    'spine': ['spine1', 'spine2'],
    'head': ['neck'],
}
for parent_part, child_parts in combine_keys.items():
    for child_part in child_parts:
        # pop() hands back the child's vertex list and removes the child entry,
        # so after the merge only the parent part remains in the mapping.
        part_segmentation[parent_part].extend(part_segmentation.pop(child_part))
# Invert the segmentation: vertex id -> name of the part that lists it.
part_segmentation_rev = {}
for part_name, vert_ids in part_segmentation.items():
    # A vertex listed under several parts ends up mapped to the last part visited.
    part_segmentation_rev.update(dict.fromkeys(vert_ids, part_name))
# Accumulate the contact labels of every sample into one per-vertex total.
per_vert_contact_count = np.zeros(6890)
for sample_label in combined_npz['contact_label']:
    # In-place add keeps the float accumulator; each sample's label is
    # broadcast against the 6890-entry total.
    np.add(per_vert_contact_count, sample_label, out=per_vert_contact_count)
# For each body part, the highest contact count reached by any of its vertices.
part_contact_max = {
    part_name: per_vert_contact_count[vert_ids].max()
    for part_name, vert_ids in part_segmentation.items()
}
# Global contact probability: each vertex's contact count divided by the count
# of the most-contacted vertex, then raised to the power 0.3. For ratios in
# (0, 1) the exponent increases the value (presumably so that rarely-contacted
# vertices stay visible in the colour map -- confirm intent with the author).
#
# The peak is computed once and the normalisation is vectorised; the previous
# per-vertex loop re-ran Python's built-in max() over the whole array for every
# single vertex.
peak_contact_count = per_vert_contact_count.max()
if peak_contact_count > 0:
    contact_prob = (per_vert_contact_count / peak_contact_count) ** 0.3
else:
    # No contact recorded anywhere: avoid 0/0 (NaN) and use zero probability.
    contact_prob = np.zeros_like(per_vert_contact_count)
# Colour the template mesh by contact probability and write it to disk.
outdir = "/is/cluster/work/stripathi/pycharm_remote/dca_contact/hot_analysis"
# Create the output directory up front so the export below does not fail
# merely because the directory is missing.
os.makedirs(outdir, exist_ok=True)
# load template smpl mesh
mesh = trimesh.load_mesh('data/smpl/smpl_neutral_tpose.ply')
# Map each vertex's contact probability to a colour of the 'jet' colour map.
vertex_colors = trimesh.visual.interpolate(contact_prob, 'jet')
# set the vertex colors of the mesh
mesh.visual.vertex_colors = vertex_colors
# save the mesh
out_path = os.path.join(outdir, "contact_probability_mesh.obj")
mesh.export(out_path)
# # calculate the contact probability per part | |
# contact_prob = np.zeros(6890) | |
# for vid in range(6890): | |
# if 'Hand' in part_segmentation_rev[vid]: | |
# contact_prob[vid] = (per_vert_contact_count[vid] / part_contact_max[part_segmentation_rev[vid]]) ** 0.4 if 'Hand' not in part_segmentation_rev[vid] else (per_vert_contact_count[vid] / part_contact_max[part_segmentation_rev[vid]]) ** 0.8 | |
# | |
# # save the contact probability mesh | |
# outdir = "/is/cluster/work/stripathi/pycharm_remote/dca_contact/hot_analysis" | |
# | |
# # load template smpl mesh | |
# mesh = trimesh.load_mesh('data/smpl/smpl_neutral_tpose.ply') | |
# vertex_colors = trimesh.visual.interpolate(contact_prob, 'jet') | |
# # set the vertex colors of the mesh | |
# mesh.visual.vertex_colors = vertex_colors | |
# # save the mesh | |
# out_path = os.path.join(outdir, "contact_probability_mesh_part.obj") | |
# mesh.export(out_path) | |