bill-jiang's picture
Init
4409449
raw
history blame
5.83 kB
# -*- coding: utf-8 -*-
# Max-Planck-Gesellschaft zur Förderung der Wissenschaften e.V. (MPG) is
# holder of all proprietary rights on this computer program.
# You can only use this computer program if you have closed
# a license agreement with MPG or you get the right to use the computer
# program from someone who is authorized to grant you that right.
# Any use of the computer program without a valid license is prohibited and
# liable to prosecution.
#
# Copyright©2020 Max-Planck-Gesellschaft zur Förderung
# der Wissenschaften e.V. (MPG). acting on behalf of its Max Planck Institute
# for Intelligent Systems. All rights reserved.
#
# Contact: [email protected]
from typing import Optional
from torch import Tensor
import smplx
from .base import Datastruct, dataclass, Transform
from .rots2rfeats import Rots2Rfeats
from .rots2joints import Rots2Joints
from .joints2jfeats import Joints2Jfeats
class SMPLTransform(Transform):
    """Holder of the three SMPL converters used by ``SMPLDatastruct``.

    The rotation<->rotation-feature converter, the rotation->joint converter
    and the joint->joint-feature converter are stored on the transform and
    handed to every datastruct it creates.
    """

    def __init__(self, rots2rfeats: Rots2Rfeats,
                 rots2joints: Rots2Joints,
                 joints2jfeats: Joints2Jfeats,
                 **kwargs):
        # Any extra keyword arguments are accepted and ignored.
        self.rots2rfeats, self.rots2joints, self.joints2jfeats = (
            rots2rfeats, rots2joints, joints2jfeats)

    def Datastruct(self, **kwargs):
        """Return an ``SMPLDatastruct`` wired to this transform's converters.

        ``kwargs`` are forwarded as the datastruct's data fields
        (e.g. ``features``).
        """
        converters = dict(_rots2rfeats=self.rots2rfeats,
                          _rots2joints=self.rots2joints,
                          _joints2jfeats=self.joints2jfeats)
        return SMPLDatastruct(transforms=self, **converters, **kwargs)

    def __repr__(self):
        return "SMPLTransform()"
class RotIdentityTransform(Transform):
    """Stateless transform that produces ``RotTransDatastruct`` objects."""

    def __init__(self, **kwargs):
        # Nothing to configure; all keyword arguments are ignored.
        pass

    def Datastruct(self, **kwargs):
        """Wrap the given fields (``rots``, ``trans``, ...) in a ``RotTransDatastruct``."""
        return RotTransDatastruct(**kwargs)

    def __repr__(self):
        return "RotIdentityTransform()"
@dataclass
class RotTransDatastruct(Datastruct):
    """Container pairing a rotation tensor with a translation tensor."""

    rots: Tensor
    trans: Tensor
    # One default instance is shared by every datastruct; this is safe
    # because RotIdentityTransform.__init__ stores no state.
    transforms: RotIdentityTransform = RotIdentityTransform()

    def __post_init__(self):
        # Names of the tensor fields carried by this struct.
        keys = ["rots", "trans"]
        self.datakeys = keys

    def __len__(self):
        # len() of a tensor is the size of its first dimension.
        rotations = self.rots
        return len(rotations)
@dataclass
class SMPLDatastruct(Datastruct):
    """Lazily derived, cached SMPL representations.

    Each representation is exposed as a property (``rots``, ``rfeats``,
    ``joints``, ``jfeats``, ``vertices``) backed by a cache field with a
    trailing underscore. On first access a missing value is derived from
    another representation with the converter fields and stored in its cache
    field:

        rfeats  <->  rots  ->  joints  ->  jfeats
                     rots  ->  vertices

    Before converting, the converter is moved (``.to``) onto the device of
    its input. Either ``features``/``rfeats_`` or ``rots_`` must be supplied;
    otherwise accessing a representation that is not already cached raises
    ``ValueError``.
    """
    transforms: SMPLTransform
    _rots2rfeats: Rots2Rfeats
    _rots2joints: Rots2Joints
    _joints2jfeats: Joints2Jfeats
    features: Optional[Tensor] = None
    rots_: Optional[RotTransDatastruct] = None
    rfeats_: Optional[Tensor] = None
    joints_: Optional[Tensor] = None
    jfeats_: Optional[Tensor] = None
    vertices_: Optional[Tensor] = None

    def __post_init__(self):
        self.datakeys = ['features', 'rots_', 'rfeats_',
                         'joints_', 'jfeats_', 'vertices_']
        # ``features`` seeds the rotation-feature cache (the usual starting
        # point); an explicitly given ``rfeats_`` takes precedence.
        if self.features is not None and self.rfeats_ is None:
            self.rfeats_ = self.features

    @property
    def rots(self):
        """Rotations; decoded from ``rfeats_`` via ``_rots2rfeats.inverse`` on first access.

        Raises:
            ValueError: if neither ``rots_`` nor ``rfeats_`` is set.
        """
        if self.rots_ is None:
            if self.rfeats_ is None:
                # Explicit raise instead of ``assert``: with ``python -O`` the
                # assert disappears and ``rots``/``rfeats`` recurse forever.
                raise ValueError(
                    "SMPLDatastruct cannot compute rots: neither rots_ nor "
                    "rfeats_/features is set")
            self._rots2rfeats.to(self.rfeats_.device)
            self.rots_ = self._rots2rfeats.inverse(self.rfeats_)
        return self.rots_

    @property
    def rfeats(self):
        """Rotation features; encoded from ``rots_`` via ``_rots2rfeats`` on first access.

        Raises:
            ValueError: if neither ``rfeats_`` nor ``rots_`` is set.
        """
        if self.rfeats_ is None:
            if self.rots_ is None:
                # See ``rots``: avoid mutual recursion when asserts are off.
                raise ValueError(
                    "SMPLDatastruct cannot compute rfeats: neither rfeats_/"
                    "features nor rots_ is set")
            self._rots2rfeats.to(self.rots_.device)
            self.rfeats_ = self._rots2rfeats(self.rots_)
        return self.rfeats_

    @property
    def joints(self):
        """Joints computed from ``rots`` with ``_rots2joints`` on first access."""
        if self.joints_ is None:
            rots = self.rots  # may itself be derived (and cached) here
            self._rots2joints.to(rots.device)
            self.joints_ = self._rots2joints(rots)
        return self.joints_

    @property
    def jfeats(self):
        """Joint features computed from ``joints`` with ``_joints2jfeats`` on first access."""
        if self.jfeats_ is None:
            joints = self.joints
            self._joints2jfeats.to(joints.device)
            self.jfeats_ = self._joints2jfeats(joints)
        return self.jfeats_

    @property
    def vertices(self):
        """Mesh vertices from ``rots``, via ``_rots2joints`` with ``jointstype="vertices"``."""
        if self.vertices_ is None:
            rots = self.rots
            self._rots2joints.to(rots.device)
            self.vertices_ = self._rots2joints(rots, jointstype="vertices")
        return self.vertices_

    def __len__(self):
        # Length is defined by the rotation features (derived if needed).
        return len(self.rfeats)
def get_body_model(model_type, gender, batch_size, device='cpu', ext='pkl'):
    """Create an ``smplx`` body model from ``data/smpl_models``.

    The model file is expected at
    ``data/smpl_models/{model_type}/{MODEL_TYPE}_{GENDER}.{ext}``.

    Args:
        model_type: model family for ``smplx.create`` (e.g. ``smpl``,
            ``smplh``, ``smplx``; see the smplx tutorial). Used as-is for the
            sub-directory and upper-cased for the file-name prefix.
        gender: ``'male'``, ``'female'`` or ``'neutral'``. Non-``str`` values
            are converted with ``str(gender.astype(str))`` (presumably a 0-d
            numpy array loaded from a data file — confirm with callers).
        batch_size: positive integer batch size of the model parameters.
        device: target device passed to ``Module.to`` (e.g. ``'cpu'``,
            ``'cuda'``, ``'cuda:1'`` or a ``torch.device``).
        ext: model file extension; overridden to ``'npz'`` for the neutral
            model.

    Returns:
        The body model built by ``smplx.create``, moved to ``device``.
    """
    mtype = model_type.upper()
    # Normalise to ``str`` *before* comparing: a non-str 'neutral' used to
    # fall into the neutral branch and crash on ``.upper()``.
    if not isinstance(gender, str):
        gender = str(gender.astype(str))
    if gender == 'neutral':
        # The neutral model is loaded from an .npz file regardless of ``ext``.
        ext = 'npz'
    gender = gender.upper()
    body_model_path = f'data/smpl_models/{model_type}/{mtype}_{gender}.{ext}'
    # Pass the requested model type (previously the builtin ``type`` was
    # passed by mistake).
    body_model = smplx.create(body_model_path, model_type=model_type,
                              gender=gender, ext=ext,
                              use_pca=False,
                              num_pca_comps=12,
                              create_global_orient=True,
                              create_body_pose=True,
                              create_betas=True,
                              create_left_hand_pose=True,
                              create_right_hand_pose=True,
                              create_expression=True,
                              create_jaw_pose=True,
                              create_leye_pose=True,
                              create_reye_pose=True,
                              create_transl=True,
                              batch_size=batch_size)
    # ``.to`` handles 'cpu', 'cuda', indexed CUDA devices and torch.device;
    # previously anything other than exactly 'cuda' silently stayed on CPU.
    return body_model.to(device)