|
|
|
|
|
|
|
|
|
import torch |
|
from torch import nn |
|
from torch import nn |
|
import smplx |
|
import torch |
|
import numpy as np |
|
import pose_utils |
|
from pose_utils import inverse_perspective_projection, perspective_projection |
|
import roma |
|
import pickle |
|
import os |
|
from pose_utils.constants_service import SMPLX_DIR |
|
from pose_utils.rot6d import rotation_6d_to_matrix |
|
from smplx.lbs import vertices2joints |
|
|
|
|
|
class SMPL_Layer(nn.Module):
    """
    Extension of the SMPL Layer with information about the camera for (inverse) projection onto the camera plane.

    The layer wraps a body model built with ``smplx.create`` (stored as ``self.bm_x``), applies the
    global orientation itself (joint 0 of the pose), places the body in camera space and projects
    joints and vertices with the intrinsics ``K``.
    """

    def __init__(
        self,
        smpl_dir,
        type="smplx",
        gender="neutral",
        num_betas=10,
        kid=False,
        person_center=None,
        *args,
        **kwargs,
    ):
        """
        Args:
            - smpl_dir: model directory forwarded to ``smplx.create``
            - type: body-model type; only "smplx" is accepted (asserted below)
            - gender: forwarded to ``smplx.create``
            - num_betas: number of shape coefficients, forwarded to ``smplx.create``
            - kid: stored on the instance; not otherwise used in this class
            - person_center: optional joint name; when given, its index in the joint-name list
              is stored and ``forward`` centres the body on that joint
            - *args, **kwargs: accepted for call-site compatibility and ignored
        Raises:
            - AssertionError if ``type`` is not "smplx"
            - ValueError if ``person_center`` is not one of the joint names
        """
        super().__init__()

        # Only the SMPL-X model is created below, so reject anything else up front.
        assert type == "smplx"
        self.type = type
        self.kid = kid
        self.num_betas = num_betas
        self.bm_x = smplx.create(
            smpl_dir, "smplx", gender=gender, use_pca=False, flat_hand_mean=True, num_betas=num_betas
        )

        # Look the getter up by attribute name instead of eval()-ing a formatted string.
        self.joint_names = getattr(pose_utils, f"get_{self.type}_joint_names")()
        self.person_center = person_center
        self.person_center_idx = None
        if self.person_center is not None:
            self.person_center_idx = self.joint_names.index(self.person_center)

    def _body_model_kwargs(self, pose, shape, expression=None):
        """Split ``pose`` into the keyword arguments passed to ``self.bm_x``.

        Two joint layouts are supported along dim 1:
            - 53 joints: [0] root, [1:22] body, [22:37] left hand, [37:52] right hand, [52] jaw
            - 55 joints: [0] root, [1:22] body, [22] jaw, [23:25] skipped, [25:40] left hand,
              [40:55] right hand
        The root entry is not forwarded: ``global_orient`` and both eye poses always come from the
        body model's own stored values, repeated over the batch. ``expression`` falls back to the
        body model's stored value when None.

        Raises:
            - ValueError if dim 1 of ``pose`` is neither 53 nor 55
        """
        bs, num_joints = pose.shape[0], pose.shape[1]
        if num_joints == 53:
            left_hand, right_hand, jaw = slice(22, 37), slice(37, 52), slice(52, 53)
        elif num_joints == 55:
            left_hand, right_hand, jaw = slice(25, 40), slice(40, 55), slice(22, 23)
        else:
            raise ValueError(f"pose dim error, should be 53 or 55, but got {num_joints}")

        if expression is not None:
            expression = expression.flatten(1)
        else:
            expression = self.bm_x.expression.repeat(bs, 1)

        return {
            "betas": shape,
            "global_orient": self.bm_x.global_orient.repeat(bs, 1),
            "body_pose": pose[:, 1:22].flatten(1),
            "left_hand_pose": pose[:, left_hand].flatten(1),
            "right_hand_pose": pose[:, right_hand].flatten(1),
            "jaw_pose": pose[:, jaw].flatten(1),
            "expression": expression,
            "leye_pose": self.bm_x.leye_pose.repeat(bs, 1),
            "reye_pose": self.bm_x.reye_pose.repeat(bs, 1),
        }

    @staticmethod
    def _pelvis_translation(transl, pelvis, person_center=None):
        """Return the translation reported as ``transl_pelvis`` by ``forward``.

        With a person centre, the body is shifted by ``-person_center`` before ``transl`` is added,
        so the value is ``transl - person_center - pelvis``. Without one, ``forward`` adds
        ``pelvis`` to the translation itself, so ``transl`` is returned unchanged (with a joint
        axis inserted). In both cases
        ``verts_cam == R @ (v - pelvis) + pelvis + transl_pelvis`` for the code in ``forward``.

        NOTE(review): the no-person-centre value is derived from that identity; the original
        expression could not be evaluated in that case - confirm against consumers.
        """
        if person_center is None:
            return transl.unsqueeze(1)
        return transl.unsqueeze(1) - person_center - pelvis

    def forward(
        self,
        pose,
        shape,
        loc,
        dist,
        transl,
        K,
        expression=None,
        rot6d=False,
        j_regressor=None,
    ):
        """
        Args:
            - pose: pose of the person - torch.Tensor [bs,53,3] in axis-angle, or [bs,53,6] when
              ``rot6d`` is True. Index 0 is the global orientation, 1:22 the body, 22:37 the left
              hand, 37:52 the right hand and 52 the jaw.
            - shape: torch.Tensor [bs,num_betas] or [bs,num_betas+1]
            - loc: 2D location of the pelvis in pixel space - torch.Tensor [bs,2] (may be None
              when ``transl`` is given)
            - dist: distance of the pelvis from the camera in m - torch.Tensor [bs,1] (may be None
              when ``transl`` is given)
            - transl: optional 3D translation; when None it is recovered from ``loc``, ``dist``
              and ``K`` with ``inverse_perspective_projection``
            - K: camera intrinsics passed to the (inverse) perspective projection helpers
            - expression: optional expression coefficients; the body model's stored value is used
              when None
            - rot6d: whether ``pose`` holds 6D rotations instead of axis-angle
            - j_regressor: optional regressor; when given, the camera-space joints are regressed
              from the camera-space vertices with ``vertices2joints``
        Return:
            - dict containing a bunch of useful information about each person: "v3d", "j3d",
              "j2d", "v2d", "transl", "transl_pelvis" and "j3d_world" (the joints exactly as
              returned by the body model). An empty dict is returned for an empty batch.
        """
        if loc is not None and dist is not None:
            assert pose.shape[0] == shape.shape[0] == loc.shape[0] == dist.shape[0]
        POSE_TYPE_LENGTH = 6 if rot6d else 3
        if self.type == "smpl":
            assert len(pose.shape) == 3 and list(pose.shape[1:]) == [24, POSE_TYPE_LENGTH]
        elif self.type == "smplx":
            assert len(pose.shape) == 3 and list(pose.shape[1:]) == [
                53,
                POSE_TYPE_LENGTH,
            ]
        else:
            raise NameError
        assert len(shape.shape) == 2 and (
            list(shape.shape[1:]) == [self.num_betas] or list(shape.shape[1:]) == [self.num_betas + 1]
        )
        if loc is not None and dist is not None:
            assert len(loc.shape) == 2 and list(loc.shape[1:]) == [2]
            assert len(dist.shape) == 2 and list(dist.shape[1:]) == [1]

        bs = pose.shape[0]

        out = {}

        # Nothing to pose: bail out before touching the body model.
        if bs == 0:
            return {}

        kwargs_pose = self._body_model_kwargs(pose, shape, expression)

        # NOTE(review): with rot6d=True the 6D slices (and the model's stored global_orient / eye
        # poses) are forwarded with pose2rot=False - confirm the body model accepts that input.
        output = self.bm_x(pose2rot=not rot6d, **kwargs_pose)
        verts = output.vertices
        j3d = output.joints

        # The global orientation is applied here, not inside the body model.
        if rot6d:
            R = rotation_6d_to_matrix(pose[:, 0])
        else:
            R = roma.rotvec_to_rotmat(pose[:, 0])

        # Rotate joints and vertices about the pelvis (joint 0).
        pelvis = j3d[:, [0]]
        j3d = (R.unsqueeze(1) @ (j3d - pelvis).unsqueeze(-1)).squeeze(-1)
        verts = (R.unsqueeze(1) @ (verts - pelvis).unsqueeze(-1)).squeeze(-1)

        # Location of the person in 3D, recovered from the 2D location and distance if not given.
        if transl is None:
            if K.dtype == torch.float16:
                # Run the inverse projection in float32 and cast back
                # (presumably it is not fp16-safe - TODO confirm).
                transl = inverse_perspective_projection(
                    loc.unsqueeze(1).float(), K.float(), dist.unsqueeze(1).float()
                )[:, 0]
                transl = transl.half()
            else:
                transl = inverse_perspective_projection(loc.unsqueeze(1), K, dist.unsqueeze(1))[:, 0]

        transl_up = transl.clone()

        # Stays None unless the body is re-centred on a chosen joint; the original code left the
        # name unbound in that case and failed when building "transl_pelvis".
        person_center = None
        if self.person_center_idx is None:
            # No person centre: put the pelvis offset back into the translation.
            transl_up = transl_up + pelvis[:, 0]
        else:
            # Centre the body on the chosen joint; the translation refers to that joint.
            person_center = j3d[:, [self.person_center_idx]]
            verts = verts - person_center
            j3d = j3d - person_center

        # Move into the camera coordinate system.
        j3d_cam = j3d + transl_up.unsqueeze(1)
        verts_cam = verts + transl_up.unsqueeze(1)

        if j_regressor is not None:
            # Replace the model joints with joints regressed from the camera-space vertices.
            j3d_cam = vertices2joints(j_regressor, verts_cam)

        # Projection onto the camera plane.
        j2d = perspective_projection(j3d_cam, K)
        v2d = perspective_projection(verts_cam, K)

        out.update(
            {
                "v3d": verts_cam,
                "j3d": j3d_cam,
                "j2d": j2d,
                "v2d": v2d,
                "transl": transl,
                "transl_pelvis": self._pelvis_translation(transl, pelvis, person_center),
                "j3d_world": output.joints,
            }
        )

        return out

    def forward_local(self, pose, shape):
        """Run the body model without global orientation, translation or camera projection.

        Args:
            - pose: torch.Tensor with three dimensions and 53 or 55 joints along dim 1
              (layouts as in ``_body_model_kwargs``)
            - shape: shape coefficients forwarded as ``betas``
        Return:
            - the raw output of ``self.bm_x``, or None for an empty batch
        Raises:
            - ValueError if ``pose`` is not 3-dimensional or has neither 53 nor 55 joints
        """
        N, J, _ = pose.shape
        if N < 1:
            return None
        return self.bm_x(**self._body_model_kwargs(pose, shape))

    def convert_standard_pose(self, poses):
        """Reorder a 53-joint pose into the 55-joint layout.

        The output along dim 1 is: joints 0:22 unchanged, the jaw (input index 52), the body
        model's stored left and right eye poses, then the hands (input 22:52).

        Args:
            - poses: torch.Tensor [n,53,D]; D must match the eye-pose width so the pieces can be
              concatenated (presumably 3, axis-angle - TODO confirm)
        Return:
            - torch.Tensor [n,55,D]
        """
        n = poses.shape[0]
        poses = torch.cat(
            [
                poses[:, :22],
                poses[:, 52:53],
                self.bm_x.leye_pose.repeat(n, 1, 1),
                self.bm_x.reye_pose.repeat(n, 1, 1),
                poses[:, 22:52],
            ],
            dim=1,
        )
        return poses
|
|