# AiOS: detrsmpl/models/heads/expose_head.py
import os
import pickle
from abc import abstractmethod
from typing import List, Optional
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import build_activation_layer, initialize
from mmcv.runner.base_module import BaseModule
from detrsmpl.utils.geometry import rot6d_to_rotmat
class IterativeRegression(nn.Module):
"""Regressor for ExPose Head."""
def __init__(self,
module,
mean_param,
num_stages=1,
append_params=True,
learn_mean=False,
detach_mean=False,
dim=1,
**kwargs):
super(IterativeRegression, self).__init__()
self.module = module
self._num_stages = num_stages
self.dim = dim
if learn_mean:
self.register_parameter(
'mean_param', nn.Parameter(mean_param, requires_grad=True))
else:
self.register_buffer('mean_param', mean_param)
self.append_params = append_params
self.detach_mean = detach_mean
def get_mean(self):
"""Get the initial mean param."""
return self.mean_param.clone()
@property
def num_stages(self):
return self._num_stages
def forward(self,
features: torch.Tensor,
cond: Optional[torch.Tensor] = None):
        '''Iteratively compute parameter deltas on top of the condition.

        Parameters
        ----------
        features: torch.Tensor
            Input features of shape (batch_size, feat_dim).
        cond: torch.Tensor, optional
            Initial parameter estimate. If None, the stored mean
            parameters are used.
        '''
batch_size = features.shape[0]
expand_shape = [batch_size] + [-1] * len(features.shape[1:])
parameters = []
deltas = []
module_input = features
if cond is None:
cond = self.mean_param.expand(*expand_shape).clone()
# Detach mean
if self.detach_mean:
cond = cond.detach()
if self.append_params:
assert features is not None, (
'Features are none even though append_params is True')
module_input = torch.cat([module_input, cond], dim=self.dim)
deltas.append(self.module(module_input))
num_params = deltas[-1].shape[1]
parameters.append(cond[:, :num_params].clone() + deltas[-1])
for stage_idx in range(1, self.num_stages):
module_input = torch.cat([features, parameters[stage_idx - 1]],
dim=-1)
params_upd = self.module(module_input)
deltas.append(params_upd)
parameters.append(parameters[stage_idx - 1] + params_upd)
return parameters
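# A minimal usage sketch of IterativeRegression (illustrative only, not part
# of the original module). It assumes a toy nn.Linear as the inner module;
# the real heads plug in the MLP defined below.
#
#   feat = torch.randn(4, 2048)                       # (batch, feat_dim)
#   mean = torch.zeros(1, 22)                         # (1, param_dim)
#   toy_module = nn.Linear(2048 + 22, 22)             # append_params=True
#   reg = IterativeRegression(toy_module, mean, num_stages=3)
#   params = reg(feat)                                # list of 3 (4, 22) tensors
#   final = params[-1]                                # last-stage estimate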
class MLP(nn.Module):
"""MLP
Args:
input_dim (int): Input dim of MLP.
output_dim (int): Output dim of MLP.
layers (List): Layer dims.
activ_type (str): Activation layer type.
dropout (float): Dropout.
gain (float): Xavier init gain value.
"""
def __init__(
self,
input_dim: int,
output_dim: int,
layers: List[int] = [],
activ_type: str = 'relu',
dropout: float = 0.5,
gain: float = 0.01,
):
super(MLP, self).__init__()
curr_input_dim = input_dim
self.num_layers = len(layers)
self.blocks = nn.ModuleList()
for layer_idx, layer_dim in enumerate(layers):
if activ_type == 'none':
active = None
else:
active = build_activation_layer(
cfg=dict(type=activ_type, inplace=True))
linear = nn.Linear(curr_input_dim, layer_dim, bias=True)
curr_input_dim = layer_dim
layer = []
layer.append(linear)
if active is not None:
layer.append(active)
if dropout > 0.0:
layer.append(nn.Dropout(dropout))
block = nn.Sequential(*layer)
self.add_module('layer_{:03d}'.format(layer_idx), block)
self.blocks.append(block)
self.output_layer = nn.Linear(curr_input_dim, output_dim)
initialize(self.output_layer,
init_cfg=dict(type='Xavier',
gain=gain,
distribution='uniform'))
def forward(self, module_input):
curr_input = module_input
for block in self.blocks:
curr_input = block(curr_input)
return self.output_layer(curr_input)
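# A small illustrative sketch of the MLP block (not part of the original
# file). 'none' is used for activ_type so the example does not depend on
# the activation names registered in mmcv.
#
#   mlp = MLP(input_dim=2070, output_dim=22, layers=[1024, 1024],
#             activ_type='none', dropout=0.5, gain=0.01)
#   out = mlp(torch.randn(4, 2070))                   # -> shape (4, 22)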
class ContinuousRotReprDecoder:
"""ExPose Decoder Decode latent representation to rotation.
Args:
num_angles (int): Joint num.
dtype: dtype.
mean (torch.tensor): Mean value for params.
"""
def __init__(self, num_angles, dtype=torch.float32, mean=None):
self.num_angles = num_angles
self.dtype = dtype
if isinstance(mean, dict):
mean = mean.get('cont_rot_repr', None)
if mean is None:
mean = torch.tensor([1.0, 0.0, 0.0, 1.0, 0.0, 0.0],
dtype=self.dtype).unsqueeze(dim=0).expand(
self.num_angles, -1).contiguous().view(-1)
if not torch.is_tensor(mean):
mean = torch.tensor(mean)
mean = mean.reshape(-1, 6)
if mean.shape[0] < self.num_angles:
mean = mean.repeat(self.num_angles // mean.shape[0] + 1,
1).contiguous()
mean = mean[:self.num_angles]
elif mean.shape[0] > self.num_angles:
mean = mean[:self.num_angles]
mean = mean.reshape(-1)
self.mean = mean
def get_mean(self):
return self.mean.clone()
def get_dim_size(self):
return self.num_angles * 6
def __call__(self, module_input):
batch_size = module_input.shape[0]
reshaped_input = module_input.view(-1, 6)
rot_mats = rot6d_to_rotmat(reshaped_input)
# aa = rot6d_to_aa(reshaped_input)
# return aa.view(batch_size,-1,3)
return rot_mats.view(batch_size, -1, 3, 3)
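# Illustrative sketch of the 6D-rotation decoder (not part of the original
# file): the default per-joint mean [1, 0, 0, 1, 0, 0] decodes to identity
# rotation matrices.
#
#   decoder = ContinuousRotReprDecoder(num_angles=15)
#   latent = decoder.get_mean().view(1, -1)           # (1, 15 * 6)
#   rotmats = decoder(latent)                         # (1, 15, 3, 3)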
class ExPoseHead(BaseModule):
"""General Head for ExPose."""
def __init__(self, init_cfg=None):
super().__init__(init_cfg)
def load_regressor(self,
input_feat_dim: int = 2048,
param_mean: torch.Tensor = None,
regressor_cfg: dict = None):
"""Build regressor for ExPose Head."""
param_dim = param_mean.numel()
regressor = MLP(input_feat_dim + param_dim, param_dim, **regressor_cfg)
self.regressor = IterativeRegression(regressor,
param_mean,
num_stages=3)
def load_param_decoder(self, mean_poses_dict):
"""Build decoders for each pose."""
start = 0
mean_lst = []
self.pose_param_decoders = {}
for pose_param in self.pose_param_conf:
pose_name = pose_param['name']
num_angles = pose_param['num_angles']
if pose_param['use_mean']:
pose_decoder = ContinuousRotReprDecoder(
num_angles,
dtype=torch.float32,
mean=mean_poses_dict.get(pose_name, None))
else:
pose_decoder = ContinuousRotReprDecoder(num_angles,
dtype=torch.float32,
mean=None)
self.pose_param_decoders['{}_decoder'.format(
pose_name)] = pose_decoder
pose_dim = pose_decoder.get_dim_size()
pose_mean = pose_decoder.get_mean()
if pose_param['rotate_axis_x']:
pose_mean[3] = -1
idxs = list(range(start, start + pose_dim))
idxs = torch.tensor(idxs, dtype=torch.long)
self.register_buffer('{}_idxs'.format(pose_name), idxs)
start += pose_dim
mean_lst.append(pose_mean.view(-1))
return start, mean_lst
def get_camera_param(self, camera_cfg):
"""Build camera param."""
camera_pos_scale = camera_cfg.get('pos_func')
if camera_pos_scale == 'softplus':
camera_scale_func = F.softplus
elif camera_pos_scale == 'exp':
camera_scale_func = torch.exp
        elif camera_pos_scale in ('none', 'None'):

            def func(x):
                return x

            camera_scale_func = func
        else:
            raise ValueError(
                f'Unsupported camera pos_func: {camera_pos_scale}')
mean_scale = camera_cfg.get('mean_scale', 0.9)
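        # mean_scale is mapped through the inverse of the chosen positivity
        # function, so that camera_scale_func(initial prediction) recovers
        # the configured value, e.g. softplus(log(exp(0.9) - 1)) == 0.9.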
if camera_pos_scale == 'softplus':
mean_scale = np.log(np.exp(mean_scale) - 1)
elif camera_pos_scale == 'exp':
mean_scale = np.log(mean_scale)
camera_mean = torch.tensor([mean_scale, 0.0, 0.0], dtype=torch.float32)
camera_param_dim = 3
return camera_mean, camera_param_dim, camera_scale_func
def flat_params_to_dict(self, param_tensor):
"""Turn param tensors to dict."""
smplx_dict = {}
raw_dict = {}
for pose_param in self.pose_param_conf:
pose_name = pose_param['name']
pose_idxs = getattr(self, f'{pose_name}_idxs')
decoder = self.pose_param_decoders[f'{pose_name}_decoder']
pose = torch.index_select(param_tensor, 1, pose_idxs)
raw_dict[f'raw_{pose_name}'] = pose.clone()
smplx_dict[pose_name] = decoder(pose)
return smplx_dict, raw_dict
def get_mean(self, name, batch_size):
"""Get mean value of params."""
mean_param = self.regressor.get_mean().view(-1)
if name is None:
return mean_param.reshape(1, -1).expand(batch_size, -1)
idxs = getattr(self, f'{name}_idxs')
return mean_param[idxs].reshape(1, -1).expand(batch_size, -1)
def get_num_betas(self):
return self.num_betas
def get_num_expression_coeffs(self):
return self.num_expression_coeffs
@abstractmethod
def forward(self, features):
pass
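# A hedged sketch of the pose_param_conf entries the heads iterate over
# (keys inferred from load_param_decoder; the concrete values below are
# illustrative, not taken from an official config):
#
#   pose_param_conf = [
#       dict(name='global_orient', num_angles=1,
#            use_mean=False, rotate_axis_x=True),
#       dict(name='body_pose', num_angles=21,
#            use_mean=True, rotate_axis_x=False),
#   ]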
class ExPoseBodyHead(ExPoseHead):
"""Head for ExPose Body Model."""
def __init__(self,
init_cfg=None,
num_betas: int = 10,
num_expression_coeffs: int = 10,
mean_pose_path: str = '',
shape_mean_path: str = '',
pose_param_conf: list = None,
input_feat_dim: int = 2048,
regressor_cfg: dict = None,
camera_cfg: dict = None):
super().__init__(init_cfg)
self.num_betas = num_betas
self.num_expression_coeffs = num_expression_coeffs
# poses
self.pose_param_conf = pose_param_conf
mean_poses_dict = {}
if os.path.exists(mean_pose_path):
with open(mean_pose_path, 'rb') as f:
mean_poses_dict = pickle.load(f)
start, mean_lst = self.load_param_decoder(mean_poses_dict)
# shape
if os.path.exists(shape_mean_path):
shape_mean = torch.from_numpy(
np.load(shape_mean_path,
allow_pickle=True)).to(dtype=torch.float32).reshape(
1, -1)[:, :num_betas].reshape(-1)
else:
shape_mean = torch.zeros([num_betas], dtype=torch.float32)
shape_idxs = list(range(start, start + num_betas))
self.register_buffer('shape_idxs',
torch.tensor(shape_idxs, dtype=torch.long))
start += num_betas
mean_lst.append(shape_mean.view(-1))
# expression
expression_mean = torch.zeros([num_expression_coeffs],
dtype=torch.float32)
expression_idxs = list(range(start, start + num_expression_coeffs))
self.register_buffer('expression_idxs',
torch.tensor(expression_idxs, dtype=torch.long))
start += num_expression_coeffs
mean_lst.append(expression_mean.view(-1))
# camera
mean, dim, scale_func = self.get_camera_param(camera_cfg)
self.camera_scale_func = scale_func
camera_idxs = list(range(start, start + dim))
self.register_buffer('camera_idxs',
torch.tensor(camera_idxs, dtype=torch.long))
start += dim
mean_lst.append(mean)
param_mean = torch.cat(mean_lst).view(1, -1)
self.load_regressor(input_feat_dim, param_mean, regressor_cfg)
def forward(self, features):
"""Forward function of ExPose Body Head.
Args:
features (List[torch.tensor]) : Output of restnet.
cond : Initial params. If none, use the mean params.
"""
body_parameters = self.regressor(features)[-1]
params_dict, raw_dict = self.flat_params_to_dict(body_parameters)
params_dict['betas'] = torch.index_select(body_parameters, 1,
self.shape_idxs)
params_dict['expression'] = torch.index_select(body_parameters, 1,
self.expression_idxs)
camera_params = torch.index_select(body_parameters, 1,
self.camera_idxs)
scale = camera_params[:, 0:1]
translation = camera_params[:, 1:3]
scale = self.camera_scale_func(scale)
camera_params = torch.cat([scale, translation], dim=1)
return {
'pred_param': params_dict,
'pred_cam': camera_params,
'pred_raw': raw_dict
}
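# Minimal end-to-end sketch for ExPoseBodyHead (illustrative assumptions:
# a single global_orient pose entry and 'none' activations, so the example
# relies neither on files on disk nor on mmcv activation names):
#
#   head = ExPoseBodyHead(
#       num_betas=10,
#       num_expression_coeffs=10,
#       pose_param_conf=[dict(name='global_orient', num_angles=1,
#                             use_mean=False, rotate_axis_x=False)],
#       input_feat_dim=2048,
#       regressor_cfg=dict(layers=[1024], activ_type='none',
#                          dropout=0.5, gain=0.01),
#       camera_cfg=dict(pos_func='softplus', mean_scale=0.9))
#   out = head(torch.randn(2, 2048))
#   out['pred_param']['global_orient'].shape          # (2, 1, 3, 3)
#   out['pred_cam'].shape                             # (2, 3)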
class ExPoseHandHead(ExPoseHead):
"""Head for ExPose Hand Model."""
def __init__(self,
init_cfg=None,
num_betas: int = 10,
mean_pose_path: str = '',
pose_param_conf: list = None,
input_feat_dim: int = 2048,
regressor_cfg: dict = None,
camera_cfg: dict = None):
super().__init__(init_cfg)
self.num_betas = num_betas
# poses
self.pose_param_conf = pose_param_conf
mean_poses_dict = {}
if os.path.exists(mean_pose_path):
with open(mean_pose_path, 'rb') as f:
mean_poses_dict = pickle.load(f)
start, mean_lst = self.load_param_decoder(mean_poses_dict)
shape_mean = torch.zeros([num_betas], dtype=torch.float32)
shape_idxs = list(range(start, start + num_betas))
self.register_buffer('shape_idxs',
torch.tensor(shape_idxs, dtype=torch.long))
start += num_betas
mean_lst.append(shape_mean.view(-1))
# camera
mean, dim, scale_func = self.get_camera_param(camera_cfg)
self.camera_scale_func = scale_func
camera_idxs = list(range(start, start + dim))
self.register_buffer('camera_idxs',
torch.tensor(camera_idxs, dtype=torch.long))
start += dim
mean_lst.append(mean)
param_mean = torch.cat(mean_lst).view(1, -1)
self.load_regressor(input_feat_dim, param_mean, regressor_cfg)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, features, cond=None):
"""Forward function of ExPose Hand Head.
Args:
features (List[torch.tensor]) : Output of restnet.
cond : Initial params. If none, use the mean params.
"""
batch_size = features[-1].size(0)
features = self.avgpool(features[-1]).view(batch_size, -1)
hand_parameters = self.regressor(features, cond=cond)[-1]
params_dict, raw_dict = self.flat_params_to_dict(hand_parameters)
params_dict['betas'] = torch.index_select(hand_parameters, 1,
self.shape_idxs)
camera_params = torch.index_select(hand_parameters, 1,
self.camera_idxs)
scale = camera_params[:, 0:1]
translation = camera_params[:, 1:3]
scale = self.camera_scale_func(scale)
camera_params = torch.cat([scale, translation], dim=1)
return {
'pred_param': params_dict,
'pred_cam': camera_params,
'pred_raw': raw_dict
}
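# Unlike the body head, the hand (and face) heads expect a list of backbone
# feature maps and pool the last one themselves. Illustrative call only,
# with 'hand_head' standing in for a constructed ExPoseHandHead:
#
#   feats = [torch.randn(2, 2048, 7, 7)]              # last-stage feature map
#   out = hand_head(feats)                            # cond defaults to the mean
#   out['pred_cam'].shape                             # (2, 3)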
class ExPoseFaceHead(ExPoseHead):
"""Head for ExPose Face Model."""
def __init__(self,
init_cfg=None,
num_betas: int = 10,
num_expression_coeffs: int = 10,
pose_param_conf: list = None,
mean_pose_path: str = '',
input_feat_dim: int = 2048,
regressor_cfg: dict = None,
camera_cfg: dict = None):
super().__init__(init_cfg)
self.num_betas = num_betas
self.num_expression_coeffs = num_expression_coeffs
# poses
self.pose_param_conf = pose_param_conf
mean_poses_dict = {}
if os.path.exists(mean_pose_path):
with open(mean_pose_path, 'rb') as f:
mean_poses_dict = pickle.load(f)
start, mean_lst = self.load_param_decoder(mean_poses_dict)
# shape
shape_mean = torch.zeros([num_betas], dtype=torch.float32)
shape_idxs = list(range(start, start + num_betas))
self.register_buffer('shape_idxs',
torch.tensor(shape_idxs, dtype=torch.long))
start += num_betas
mean_lst.append(shape_mean.view(-1))
# expression
expression_mean = torch.zeros([num_expression_coeffs],
dtype=torch.float32)
expression_idxs = list(range(start, start + num_expression_coeffs))
self.register_buffer('expression_idxs',
torch.tensor(expression_idxs, dtype=torch.long))
start += num_expression_coeffs
mean_lst.append(expression_mean.view(-1))
# camera
mean, dim, scale_func = self.get_camera_param(camera_cfg)
self.camera_scale_func = scale_func
camera_idxs = list(range(start, start + dim))
self.register_buffer('camera_idxs',
torch.tensor(camera_idxs, dtype=torch.long))
start += dim
mean_lst.append(mean)
param_mean = torch.cat(mean_lst).view(1, -1)
self.load_regressor(input_feat_dim, param_mean, regressor_cfg)
self.avgpool = nn.AdaptiveAvgPool2d((1, 1))
def forward(self, features, cond=None):
"""Forward function of ExPose Face Head.
Args:
features (List[torch.tensor]) : Output of restnet.
cond : Initial params. If none, use the mean params.
"""
batch_size = features[-1].size(0)
features = self.avgpool(features[-1]).view(batch_size, -1)
head_parameters = self.regressor(features, cond=cond)[-1]
params_dict, raw_dict = self.flat_params_to_dict(head_parameters)
params_dict['betas'] = torch.index_select(head_parameters, 1,
self.shape_idxs)
params_dict['expression'] = torch.index_select(head_parameters, 1,
self.expression_idxs)
camera_params = torch.index_select(head_parameters, 1,
self.camera_idxs)
scale = camera_params[:, 0:1]
translation = camera_params[:, 1:3]
scale = self.camera_scale_func(scale)
camera_params = torch.cat([scale, translation], dim=1)
return {
'pred_param': params_dict,
'pred_cam': camera_params,
'pred_raw': raw_dict
}