Spaces:
Build error
Build error
# -*- coding: utf-8 -*- | |
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is | |
# holder of all proprietary rights on this computer program. | |
# You can only use this computer program if you have closed | |
# a license agreement with MPG or you get the right to use the computer | |
# program from someone who is authorized to grant you that right. | |
# Any use of the computer program without a valid license is prohibited and | |
# liable to prosecution. | |
# | |
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung | |
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute | |
# for Intelligent Systems. All rights reserved. | |
# | |
# Contact: [email protected] | |
import contextlib | |
from typing import Optional | |
import torch | |
from torch import Tensor | |
from mGPT.utils.joints import smplh_to_mmm_scaling_factor, smplh2mmm_indexes, get_root_idx | |
from mGPT.utils.easyconvert import rep_to_rep | |
from .base import Rots2Joints | |
def slice_or_none(data, cslice): | |
if data is None: | |
return data | |
else: | |
return data[cslice] | |
class SMPLX(Rots2Joints): | |
def __init__(self, | |
path: str, | |
jointstype: str = "mmm", | |
input_pose_rep: str = "matrix", | |
batch_size: int = 512, | |
gender="neutral", | |
**kwargs) -> None: | |
super().__init__(path=None, normalization=False) | |
self.batch_size = batch_size | |
self.input_pose_rep = input_pose_rep | |
self.jointstype = jointstype | |
self.training = False | |
from smplx.body_models import SMPLXLayer | |
import os | |
# rel_p = path.split('/') | |
# rel_p = rel_p[rel_p.index('data'):] | |
# rel_p = '/'.join(rel_p) | |
# Remove annoying print | |
with contextlib.redirect_stdout(None): | |
self.smplx = SMPLXLayer(path, | |
ext="npz", | |
gender=gender, | |
batch_size=batch_size).eval() | |
self.faces = self.smplx.faces | |
for p in self.parameters(): | |
p.requires_grad = False | |
def train(self, *args, **kwargs): | |
return self | |
def forward(self, | |
smpl_data: dict, | |
jointstype: Optional[str] = None, | |
input_pose_rep: Optional[str] = None, | |
batch_size: Optional[int] = None) -> Tensor: | |
# Take values from init if not specified there | |
jointstype = self.jointstype if jointstype is None else jointstype | |
batch_size = self.batch_size if batch_size is None else batch_size | |
input_pose_rep = self.input_pose_rep if input_pose_rep is None else input_pose_rep | |
poses = smpl_data.rots | |
trans = smpl_data.trans | |
from functools import reduce | |
import operator | |
save_shape_bs_len = poses.shape[:-3] | |
nposes = reduce(operator.mul, save_shape_bs_len, 1) | |
matrix_poses = rep_to_rep(self.input_pose_rep, input_pose_rep, poses) | |
# Reshaping | |
matrix_poses = matrix_poses.reshape((nposes, *matrix_poses.shape[-3:])) | |
global_orient = matrix_poses[:, 0] | |
if trans is None: | |
trans = torch.zeros((*save_shape_bs_len, 3), | |
dtype=poses.dtype, | |
device=poses.device) | |
trans_all = trans.reshape((nposes, *trans.shape[-1:])) | |
body_pose = matrix_poses[:, 1:22] | |
if poses.shape[-3] == 55: | |
nohands = False | |
nofaces = False | |
elif poses.shape[-3] == 52: | |
nohands = False | |
nofaces = True | |
elif poses.shape[-3] == 22: | |
nohands = True | |
nofaces = True | |
else: | |
raise NotImplementedError("Could not parse the poses.") | |
if nohands: | |
left_hand_pose = None | |
right_hand_pose = None | |
else: | |
left_hand_pose = matrix_poses[:, 25:40] | |
right_hand_pose = matrix_poses[:, 40:55] | |
if nofaces: | |
jaw_pose = None | |
leye_pose = None | |
reye_pose = None | |
else: | |
jaw_pose = matrix_poses[:, 22:23] | |
leye_pose = matrix_poses[:, 23:24] | |
reye_pose = matrix_poses[:, 24:25] | |
n = len(body_pose) | |
outputs = [] | |
for chunk in range(int((n - 1) / batch_size) + 1): | |
chunk_slice = slice(chunk * batch_size, (chunk + 1) * batch_size) | |
smpl_output = self.smplx( | |
global_orient=slice_or_none(global_orient, chunk_slice), | |
body_pose=slice_or_none(body_pose, chunk_slice), | |
left_hand_pose=slice_or_none(left_hand_pose, chunk_slice), | |
right_hand_pose=slice_or_none(right_hand_pose, chunk_slice), | |
jaw_pose=slice_or_none(jaw_pose, chunk_slice), | |
leye_pose=slice_or_none(leye_pose, chunk_slice), | |
reye_pose=slice_or_none(reye_pose, chunk_slice), | |
transl=slice_or_none(trans_all, chunk_slice)) | |
if jointstype == "vertices": | |
output_chunk = smpl_output.vertices | |
else: | |
joints = smpl_output.joints | |
output_chunk = joints | |
outputs.append(output_chunk) | |
outputs = torch.cat(outputs) | |
outputs = outputs.reshape((*save_shape_bs_len, *outputs.shape[1:])) | |
# Change topology if needed | |
outputs = smplx_to(jointstype, outputs, trans) | |
return outputs | |
def inverse(self, joints: Tensor) -> Tensor: | |
raise NotImplementedError("Cannot inverse SMPLX layer.") | |
def smplx_to(jointstype, data, trans): | |
if "mmm" in jointstype: | |
indexes = smplh2mmm_indexes | |
data = data[..., indexes, :] | |
# make it compatible with mmm | |
if jointstype == "mmm": | |
data *= smplh_to_mmm_scaling_factor | |
if jointstype == "smplmmm": | |
pass | |
elif jointstype in ["mmm", "mmmns"]: | |
# swap axis | |
data = data[..., [1, 2, 0]] | |
# revert left and right | |
data[..., 2] = -data[..., 2] | |
elif jointstype == "smplnh": | |
from mGPT.utils.joints import smplh2smplnh_indexes | |
indexes = smplh2smplnh_indexes | |
data = data[..., indexes, :] | |
elif jointstype == "smplh": | |
pass | |
elif jointstype == "vertices": | |
pass | |
else: | |
raise NotImplementedError(f"SMPLX to {jointstype} is not implemented.") | |
if jointstype != "vertices": | |
# shift the output in each batch | |
# such that it is centered on the pelvis/root on the first frame | |
root_joint_idx = get_root_idx(jointstype) | |
shift = trans[..., 0, :] - data[..., 0, root_joint_idx, :] | |
data += shift[..., None, None, :] | |
return data | |