File size: 4,338 Bytes
f9e76ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
# Visualisation code for SMPL-X model. This code is useful if you already have predictions.

import os
import sys
import os.path as osp
import numpy as np
import smplx
from smplx.joint_names import JOINT_NAMES
import torch
try:
    CUR_DIR = osp.dirname(os.path.abspath(__file__))
except NameError:
    CUR_DIR = os.getcwd()
sys.path.insert(0, osp.join(CUR_DIR, '..', 'main'))
sys.path.insert(0, osp.join(CUR_DIR , '..', 'common'))
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D

# Map each SMPL-X joint name to its index in smplx's JOINT_NAMES list.
JOINT_NAMES_DICT = {name: i for i, name in enumerate(JOINT_NAMES)}

# Locations of the SMPL-X model files and of the predicted parameters.
# The defaults are unchanged, but both can be overridden through environment
# variables so the script does not have to be edited to run on another machine.
model_path = os.environ.get('SMPLX_MODEL_PATH', 'common/utils/human_model_files')
params_path = os.environ.get('SMPLX_PARAMS_PATH', '/home/sahand/Downloads/smplx/00047_9.npz')

# Load the SMPL-X model
model = smplx.create(model_path, model_type='smplx', gender='neutral', ext='npz')


def _batched_param(archive, key):
    """Return ``archive[key]`` as a float32 tensor, adding a leading batch axis if it is 1-D."""
    tensor = torch.tensor(archive[key], dtype=torch.float32)
    if tensor.ndim == 1:
        tensor = tensor.unsqueeze(0)
    return tensor


# Load the parameters from the .npz file. The context manager closes the
# archive once every array has been copied into a tensor.
with np.load(params_path) as data:
    betas = _batched_param(data, 'betas')
    global_orient = _batched_param(data, 'global_orient')
    transl = _batched_param(data, 'transl')
    expression = _batched_param(data, 'expression')
    # Whatever layout body_pose is stored in, flatten it to (1, N, 3);
    # this also supplies the batch dimension.
    body_pose = torch.tensor(data['body_pose'], dtype=torch.float32).reshape(1, -1, 3)

# Forward pass through the model. Only the resulting joint positions are
# needed for plotting, so autograd bookkeeping is switched off.
with torch.no_grad():
    output = model(betas=betas, body_pose=body_pose, global_orient=global_orient, transl=transl, expression=expression)

# Extract joint positions as a NumPy array with singleton dimensions removed
joints = output.joints.detach().cpu().numpy().squeeze()
print(joints.shape)
# Knee and ankle joints (left and right), looked up by name so the indices
# stay consistent with the name-based bone list used for drawing.
left_knee = joints[JOINT_NAMES_DICT["left_knee"]]
right_knee = joints[JOINT_NAMES_DICT["right_knee"]]
left_ankle = joints[JOINT_NAMES_DICT["left_ankle"]]
right_ankle = joints[JOINT_NAMES_DICT["right_ankle"]]

# Skeleton edges as (start_index, end_index) pairs. The bones are written as
# joint-name pairs and resolved to indices through JOINT_NAMES_DICT.
bone_connections = [
    (JOINT_NAMES_DICT[parent], JOINT_NAMES_DICT[child])
    for parent, child in (
        # Spine
        ("pelvis", "spine1"),
        ("spine1", "spine2"),
        ("spine2", "spine3"),
        # Left leg
        ("pelvis", "left_hip"),
        ("left_hip", "left_knee"),
        ("left_knee", "left_ankle"),
        # Right leg
        ("pelvis", "right_hip"),
        ("right_hip", "right_knee"),
        ("right_knee", "right_ankle"),
        # Heels
        ("left_ankle", "left_heel"),
        ("right_ankle", "right_heel"),
        # Left foot and toes
        ("left_ankle", "left_foot"),
        ("left_foot", "left_big_toe"),
        ("left_foot", "left_small_toe"),
        # Right foot and toes
        ("right_ankle", "right_foot"),
        ("right_foot", "right_big_toe"),
        ("right_foot", "right_small_toe"),
        # Add more bones if necessary
    )
]

# Visualize the 3D skeleton
fig = plt.figure()
ax = fig.add_subplot(111, projection='3d')

# Plot all joints
ax.scatter(joints[:, 0], joints[:, 1], joints[:, 2], c='blue', marker='o')

# Highlight knee and ankle joints: (position, colour, marker, legend label)
highlighted_joints = (
    (left_knee, 'red', 'x', 'Left Knee'),
    (right_knee, 'green', 'x', 'Right Knee'),
    (left_ankle, 'red', 'o', 'Left Ankle'),
    (right_ankle, 'green', 'o', 'Right Ankle'),
)
for position, colour, marker_style, legend_label in highlighted_joints:
    ax.scatter([position[0]], [position[1]], [position[2]],
               c=colour, marker=marker_style, s=100, label=legend_label)

# Draw bones as black line segments between the two joints of each connection
for start, end in bone_connections:
    segment = joints[[start, end]]
    ax.plot(segment[:, 0], segment[:, 1], segment[:, 2], 'k-')

# Set labels
ax.set_xlabel('X')
ax.set_ylabel('Y')
ax.set_zlabel('Z')

# By default each 3D axis is stretched to fill the same box length, which
# distorts the skeleton's proportions. Make the box sides proportional to the
# axis ranges so one data unit has the same length along X, Y and Z.
# set_box_aspect is missing from older matplotlib releases, hence the guard.
if hasattr(ax, 'set_box_aspect'):
    axis_limits = np.array([ax.get_xlim(), ax.get_ylim(), ax.get_zlim()])
    ax.set_box_aspect(axis_limits[:, 1] - axis_limits[:, 0])

ax.legend()
plt.show()