|
import torch |
|
|
|
|
|
def to3d(poses, config):
    """Decode packed (x, y, z, sin, cos) joint rows into 3-values-per-joint poses.

    When ``config.Data.pose.convert_to_6d`` is falsy, ``poses`` is returned
    untouched. Otherwise, each row is split as follows:

    - If ``config.Data.pose.expression`` is set, the last 100 columns are
      split off first.
    - The remaining columns are viewed as groups of 5 values per joint.
    - Each group's first three values are scaled by ``atan2(v[3], v[4])``,
      and the sin/cos pair is dropped.
    - The expression columns, if any, are appended back unchanged.

    Args:
        poses: Tensor assumed to be 2-D (frames, features), where features is
            presumably 5 * joints (+ 100 when expression is enabled).
            NOTE(review): the first three values of each group are presumably
            a unit rotation axis; they are not normalized here.
        config: Object exposing ``Data.pose.convert_to_6d`` and
            ``Data.pose.expression``.

    Returns:
        ``poses`` itself when conversion is disabled. Otherwise a tensor of
        shape (frames, 3 * joints), with the 100 expression columns appended
        when enabled.
    """
    pose_cfg = config.Data.pose
    if not pose_cfg.convert_to_6d:
        return poses

    has_expression = pose_cfg.expression
    if has_expression:
        body, expression = poses[:, :-100], poses[:, -100:]
    else:
        body = poses

    n_frames = body.shape[0]
    packed = body.reshape(n_frames, -1, 5)
    # Recover the angle from its (sin, cos) encoding, then scale the 3-vector.
    angle = torch.atan2(packed[:, :, 3], packed[:, :, 4])
    decoded = (packed[:, :, :3] * angle[..., None]).reshape(n_frames, -1)

    if has_expression:
        return torch.cat([decoded, expression], dim=-1)
    return decoded
|
|
|
|
|
def get_joint(smplx_model, betas, pred):
    """Run ``smplx_model`` on one batch of packed parameter rows and return its joints.

    Each row of ``pred`` is sliced by fixed column ranges (0-264) into the
    keyword arguments listed in ``layout`` below. ``betas`` is repeated once
    per row of ``pred``.

    Args:
        smplx_model: Callable accepting SMPL-X style keyword arguments and
            returning a mapping with a ``'joints'`` entry.
        betas: Tensor repeated ``pred.shape[0]`` times along dim 0
            (presumably shape (1, n_betas) — TODO confirm against callers).
        pred: 2-D tensor whose rows hold the packed parameters.

    Returns:
        ``smplx_model(...)['joints']``.
    """
    n_rows = pred.shape[0]
    # (keyword, first column, one-past-last column) within a packed row.
    layout = (
        ("jaw_pose", 0, 3),
        ("leye_pose", 3, 6),
        ("reye_pose", 6, 9),
        ("global_orient", 9, 12),
        ("body_pose", 12, 75),
        ("left_hand_pose", 75, 120),
        ("right_hand_pose", 120, 165),
        ("expression", 165, 265),
    )
    params = {name: pred[:, lo:hi] for name, lo, hi in layout}
    output = smplx_model(betas=betas.repeat(n_rows, 1), return_verts=True, **params)
    return output['joints']
|
|
|
|
|
def get_joints(smplx_model, betas, pred, chunk_size=4):
    """Compute joints for predicted parameters, chunking 3-D input by sequence.

    For 3-D input ``(B, T, D)``, the sequences are flattened to ``B * T``
    rows. They are fed to ``get_joint`` in chunks of at most ``chunk_size``
    sequences (``chunk_size * T`` rows). The last chunk may be shorter when
    ``B`` is not a multiple of ``chunk_size``; previously those trailing
    sequences were silently dropped. Any other input is passed to
    ``get_joint`` in a single call.

    Args:
        smplx_model: Model forwarded to ``get_joint``. Its ``batch_size``
            attribute is overwritten to equal the number of rows in each call.
        betas: Shape coefficients forwarded to ``get_joint``.
        pred: Either ``(B, T, D)`` sequences or a 2-D batch of packed rows,
            in the column layout expected by ``get_joint``.
        chunk_size: Maximum number of sequences per ``get_joint`` call for
            3-D input (default 4).

    Returns:
        For 3-D input, joints reshaped to ``(B, T, -1, 3)``. Otherwise the
        value returned by ``get_joint``.

    Raises:
        ValueError: If 3-D ``pred`` has no sequences or no frames, or if
            ``chunk_size`` is smaller than 1.
    """
    if len(pred.shape) != 3:
        smplx_model.batch_size = pred.shape[0]
        return get_joint(smplx_model, betas, pred)

    if chunk_size < 1:
        raise ValueError("chunk_size must be >= 1")
    n_seq, n_frames = pred.shape[0], pred.shape[1]
    if n_seq == 0 or n_frames == 0:
        raise ValueError("get_joints needs at least one sequence and one frame")

    # Keep the original row width instead of assuming it, so frames never mix.
    flat = pred.reshape(n_seq * n_frames, -1)
    rows_per_call = min(chunk_size, n_seq) * n_frames

    joints = []
    for start in range(0, flat.shape[0], rows_per_call):
        chunk = flat[start:start + rows_per_call]
        # The model's batch_size must match the rows actually passed in; the
        # final chunk may be shorter than rows_per_call.
        smplx_model.batch_size = chunk.shape[0]
        joints.append(get_joint(smplx_model, betas, chunk))

    return torch.cat(joints, dim=0).reshape(n_seq, n_frames, -1, 3)