bill-jiang's picture
Init
4409449
raw
history blame
12.7 kB
import torch
import os, sys
import pickle
import smplx
import numpy as np
from tqdm import tqdm
sys.path.append(os.path.dirname(__file__))
from customloss import (camera_fitting_loss,
body_fitting_loss,
camera_fitting_loss_3d,
body_fitting_loss_3d,
)
from prior import MaxMixturePrior
import config
@torch.no_grad()
def guess_init_3d(model_joints,
j3d,
joints_category="orig"):
"""Initialize the camera translation via triangle similarity, by using the torso joints .
:param model_joints: SMPL model with pre joints
:param j3d: 25x3 array of Kinect Joints
:returns: 3D vector corresponding to the estimated camera translation
"""
# get the indexed four
gt_joints = ['RHip', 'LHip', 'RShoulder', 'LShoulder']
gt_joints_ind = [config.JOINT_MAP[joint] for joint in gt_joints]
if joints_category=="orig":
joints_ind_category = [config.JOINT_MAP[joint] for joint in gt_joints]
elif joints_category=="AMASS":
joints_ind_category = [config.AMASS_JOINT_MAP[joint] for joint in gt_joints]
elif joints_category=="MMM":
joints_ind_category = [config.MMM_JOINT_MAP[joint] for joint in gt_joints]
else:
print("NO SUCH JOINTS CATEGORY!")
sum_init_t = (j3d[:, joints_ind_category] - model_joints[:, gt_joints_ind]).sum(dim=1)
init_t = sum_init_t / 4.0
return init_t
# SMPLIfy 3D
class SMPLify3D():
"""Implementation of SMPLify, use 3D joints."""
def __init__(self,
smplxmodel,
step_size=1e-2,
batch_size=1,
num_iters=100,
use_collision=False,
use_lbfgs=True,
joints_category="orig",
device=torch.device('cuda:0'),
):
# Store options
self.batch_size = batch_size
self.device = device
self.step_size = step_size
self.num_iters = num_iters
# --- choose optimizer
self.use_lbfgs = use_lbfgs
# GMM pose prior
self.pose_prior = MaxMixturePrior(prior_folder=config.GMM_MODEL_DIR,
num_gaussians=8,
dtype=torch.float32).to(device)
# collision part
self.use_collision = use_collision
if self.use_collision:
self.part_segm_fn = config.Part_Seg_DIR
# reLoad SMPL-X model
self.smpl = smplxmodel
self.model_faces = smplxmodel.faces_tensor.view(-1)
# select joint joint_category
self.joints_category = joints_category
if joints_category=="orig":
self.smpl_index = config.full_smpl_idx
self.corr_index = config.full_smpl_idx
elif joints_category=="AMASS":
self.smpl_index = config.amass_smpl_idx
self.corr_index = config.amass_idx
# elif joints_category=="MMM":
# self.smpl_index = config.mmm_smpl_dix
# self.corr_index = config.mmm_idx
else:
self.smpl_index = None
self.corr_index = None
print("NO SUCH JOINTS CATEGORY!")
# ---- get the man function here ------
def __call__(self, init_pose, init_betas, init_cam_t, j3d, conf_3d=1.0, seq_ind=0):
"""Perform body fitting.
Input:
init_pose: SMPL pose estimate
init_betas: SMPL betas estimate
init_cam_t: Camera translation estimate
j3d: joints 3d aka keypoints
conf_3d: confidence for 3d joints
seq_ind: index of the sequence
Returns:
vertices: Vertices of optimized shape
joints: 3D joints of optimized shape
pose: SMPL pose parameters of optimized shape
betas: SMPL beta parameters of optimized shape
camera_translation: Camera translation
"""
# # # add the mesh inter-section to avoid
search_tree = None
pen_distance = None
filter_faces = None
if self.use_collision:
from mesh_intersection.bvh_search_tree import BVH
import mesh_intersection.loss as collisions_loss
from mesh_intersection.filter_faces import FilterFaces
search_tree = BVH(max_collisions=8)
pen_distance = collisions_loss.DistanceFieldPenetrationLoss(
sigma=0.5, point2plane=False, vectorized=True, penalize_outside=True)
if self.part_segm_fn:
# Read the part segmentation
part_segm_fn = os.path.expandvars(self.part_segm_fn)
with open(part_segm_fn, 'rb') as faces_parents_file:
face_segm_data = pickle.load(faces_parents_file, encoding='latin1')
faces_segm = face_segm_data['segm']
faces_parents = face_segm_data['parents']
# Create the module used to filter invalid collision pairs
filter_faces = FilterFaces(
faces_segm=faces_segm, faces_parents=faces_parents,
ign_part_pairs=None).to(device=self.device)
# Split SMPL pose to body pose and global orientation
body_pose = init_pose[:, 3:].detach().clone()
global_orient = init_pose[:, :3].detach().clone()
betas = init_betas.detach().clone()
# use guess 3d to get the initial
smpl_output = self.smpl(global_orient=global_orient,
body_pose=body_pose,
betas=betas)
model_joints = smpl_output.joints
init_cam_t = guess_init_3d(model_joints, j3d, self.joints_category).detach()
camera_translation = init_cam_t.clone()
preserve_pose = init_pose[:, 3:].detach().clone()
# -------------Step 1: Optimize camera translation and body orientation--------
# Optimize only camera translation and body orientation
body_pose.requires_grad = False
betas.requires_grad = False
global_orient.requires_grad = True
camera_translation.requires_grad = True
camera_opt_params = [global_orient, camera_translation]
if self.use_lbfgs:
camera_optimizer = torch.optim.LBFGS(camera_opt_params, max_iter=self.num_iters,
lr=self.step_size, line_search_fn='strong_wolfe')
for i in range(10):
def closure():
camera_optimizer.zero_grad()
smpl_output = self.smpl(global_orient=global_orient,
body_pose=body_pose,
betas=betas)
model_joints = smpl_output.joints
loss = camera_fitting_loss_3d(model_joints, camera_translation,
init_cam_t, j3d, self.joints_category)
loss.backward()
return loss
camera_optimizer.step(closure)
else:
camera_optimizer = torch.optim.Adam(camera_opt_params, lr=self.step_size, betas=(0.9, 0.999))
for i in range(20):
smpl_output = self.smpl(global_orient=global_orient,
body_pose=body_pose,
betas=betas)
model_joints = smpl_output.joints
loss = camera_fitting_loss_3d(model_joints[:, self.smpl_index], camera_translation,
init_cam_t, j3d[:, self.corr_index], self.joints_category)
camera_optimizer.zero_grad()
loss.backward()
camera_optimizer.step()
# Fix camera translation after optimizing camera
# --------Step 2: Optimize body joints --------------------------
# Optimize only the body pose and global orientation of the body
body_pose.requires_grad = True
global_orient.requires_grad = True
camera_translation.requires_grad = True
# --- if we use the sequence, fix the shape
if seq_ind == 0:
betas.requires_grad = True
body_opt_params = [body_pose, betas, global_orient, camera_translation]
else:
betas.requires_grad = False
body_opt_params = [body_pose, global_orient, camera_translation]
if self.use_lbfgs:
body_optimizer = torch.optim.LBFGS(body_opt_params, max_iter=self.num_iters,
lr=self.step_size, line_search_fn='strong_wolfe')
for i in tqdm(range(self.num_iters), desc=f"LBFGS iter: "):
# for i in range(self.num_iters):
def closure():
body_optimizer.zero_grad()
smpl_output = self.smpl(global_orient=global_orient,
body_pose=body_pose,
betas=betas)
model_joints = smpl_output.joints
model_vertices = smpl_output.vertices
loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation,
j3d[:, self.corr_index], self.pose_prior,
joints3d_conf=conf_3d,
joint_loss_weight=600.0,
pose_preserve_weight=5.0,
use_collision=self.use_collision,
model_vertices=model_vertices, model_faces=self.model_faces,
search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces)
loss.backward()
return loss
body_optimizer.step(closure)
else:
body_optimizer = torch.optim.Adam(body_opt_params, lr=self.step_size, betas=(0.9, 0.999))
for i in range(self.num_iters):
smpl_output = self.smpl(global_orient=global_orient,
body_pose=body_pose,
betas=betas)
model_joints = smpl_output.joints
model_vertices = smpl_output.vertices
loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation,
j3d[:, self.corr_index], self.pose_prior,
joints3d_conf=conf_3d,
joint_loss_weight=600.0,
use_collision=self.use_collision,
model_vertices=model_vertices, model_faces=self.model_faces,
search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces)
body_optimizer.zero_grad()
loss.backward()
body_optimizer.step()
# Get final loss value
with torch.no_grad():
smpl_output = self.smpl(global_orient=global_orient,
body_pose=body_pose,
betas=betas, return_full_pose=True)
model_joints = smpl_output.joints
model_vertices = smpl_output.vertices
final_loss = body_fitting_loss_3d(body_pose, preserve_pose, betas, model_joints[:, self.smpl_index], camera_translation,
j3d[:, self.corr_index], self.pose_prior,
joints3d_conf=conf_3d,
joint_loss_weight=600.0,
use_collision=self.use_collision, model_vertices=model_vertices, model_faces=self.model_faces,
search_tree=search_tree, pen_distance=pen_distance, filter_faces=filter_faces)
vertices = smpl_output.vertices.detach()
joints = smpl_output.joints.detach()
pose = torch.cat([global_orient, body_pose], dim=-1).detach()
betas = betas.detach()
return vertices, joints, pose, betas, camera_translation, final_loss