Spaces:
Sleeping
Sleeping
import math | |
import torch | |
import torch.nn as nn | |
from . import geometry | |
from einops import rearrange | |
import itertools | |
def _joints_to_bones(joints, bones_idxs):
    """
    Assemble bone segments from joint positions.

    Args:
        joints: tensor indexed as joints[:, :, joint_id, :]; callers in this
            file pass (B, F, n_joints, 3).
        bones_idxs: iterable of (a, b) joint-index pairs, one per bone.
    Returns:
        Tensor with the bones stacked on dim 2 and the two end points of each
        bone on dim 3 (joint ``a`` first, then joint ``b``), i.e.
        (B, F, num_bones, 2, 3) for the input above.
    """
    segments = [
        torch.stack([joints[:, :, start, :], joints[:, :, end, :]], dim=2)
        for start, end in bones_idxs
    ]
    return torch.stack(segments, dim=2)
def _compute_vertices_to_bones_weights(bones_pred, seq_shape_pred, temperature=1):
    """
    Compute soft vertex-to-bone assignment weights.

    For every bone (dim 2 of ``bones_pred``) the two bone end points and the
    vertices are handed to ``geometry.line_segment_distance`` (presumably the
    per-vertex distance to the bone segment -- confirm in ``geometry``). The
    negated results, divided by ``temperature``, are soft-maxed across bones.

    Returns:
        Tensor with the bone axis first (dim 0); the weights sum to 1 over
        that axis.
    """
    num_bones = bones_pred.shape[2]
    per_bone = [
        geometry.line_segment_distance(bones_pred[:, :, k, 0], bones_pred[:, :, k, 1], seq_shape_pred)
        for k in range(num_bones)
    ]
    # Negated values act as logits, so a smaller distance yields a larger weight.
    # (An inverse-distance variant, softmax(1 / d / temperature), was tried and disabled.)
    logits = -torch.stack(per_bone) / temperature
    return nn.functional.softmax(logits, dim=0)
def build_kinematic_chain(n_bones, start_bone_idx):
    """
    Build a linear chain of ``n_bones`` bones, numbered from ``start_bone_idx``.

    The first bone created is the leaf; every later bone depends on all bones
    created before it.

    Returns:
        (bones_to_joints, kinematic_chain, bone_idxs) where
        bones_to_joints: [(1, 0), (2, 1), ...], a joint-index pair per bone.
        kinematic_chain: [(bone_idx, dependent_bones), ...] with the bone
            created last (the one with the most dependents) at the front.
            Every entry owns its own list object.
        bone_idxs: all bone indices of the chain, in creation order.
    """
    bone_ids = [start_bone_idx + k for k in range(n_bones)]
    bones_to_joints = [(k + 1, k) for k in range(n_bones)]
    # Slicing produces a fresh list for every entry, so a later in-place
    # extension of one entry never leaks into another.
    kinematic_chain = [(bone_ids[k], bone_ids[:k]) for k in reversed(range(n_bones))]  # parent is always in the front
    return bones_to_joints, kinematic_chain, bone_ids
def update_body_kinematic_chain(kinematic_chain, leg_kinematic_chain, body_bone_idx, leg_bone_idxs, attach_legs_to_body=True):
    """
    Append a leg chain to the body chain.

    When ``attach_legs_to_body`` is set, ``leg_bone_idxs`` are added as
    dependents of ``body_bone_idx`` and of every bone that already lists
    ``body_bone_idx`` as a dependent. The dependent lists inside
    ``kinematic_chain`` are extended in place; the outer list is not.

    Returns:
        A new list: the body entries followed by the leg entries.
    """
    if attach_legs_to_body:
        for owner, dependents in kinematic_chain:
            drives_attachment = owner == body_bone_idx or body_bone_idx in dependents
            if drives_attachment:
                dependents.extend(leg_bone_idxs)
    return kinematic_chain + leg_kinematic_chain  # parent is always in the front
def lift_points_mesh(points, seq_shape, size_aspect=0.5):
    """
    Raise the interior joints of a linearly interpolated chain along y.

    For every joint except the first and the last, a z-window is built that
    reaches ``size_aspect`` of the way towards each neighbouring joint. The
    joint's y is replaced by the largest y among the mesh vertices whose z lies
    strictly inside that window; a joint with no vertex in its window is left
    as it was.

    ``points`` is modified in place and also returned.
    """
    inner = points[:, :, 1:-1, :]
    inner_z = inner[..., 2]
    n_vertices = seq_shape.shape[-2]

    # Window bounds: part-way towards the previous joint (upper bound) and the
    # next joint (lower bound), broadcast over all mesh vertices.
    z_upper = inner_z - size_aspect * (inner_z - points[:, :, :-2, 2])
    z_lower = inner_z - size_aspect * (inner_z - points[:, :, 2:, 2])
    z_upper = z_upper.unsqueeze(-1).expand(-1, -1, -1, n_vertices)
    z_lower = z_lower.unsqueeze(-1).expand(-1, -1, -1, n_vertices)

    mesh = seq_shape.unsqueeze(2).expand(-1, -1, inner.shape[-2], -1, -1)
    inside = ((mesh[..., 2] > z_lower) & (mesh[..., 2] < z_upper)).float()

    # Vertices outside the window are pushed to a sentinel so they never win the max.
    sentinel = -1e6
    best_y = (mesh[..., 1] * inside + sentinel * (1 - inside)).max(dim=-1).values
    found = (best_y != sentinel).float()

    points[:, :, 1:-1, 1] = points[:, :, 1:-1, 1] * (1 - found) + best_y * found
    return points
def _body_endpoints(seq_shape, body_bones_type):
    """Return the two mesh vertices (max-z pick, min-z pick), each of shape (B, F, 3), that anchor the body chain."""
    if body_bones_type == 'z_minmax':
        zs = seq_shape[..., 2]
        indices = torch.stack([zs.argmax(dim=2), zs.argmin(dim=2)], dim=2)  # Shape: (B, F, 2)
        points = seq_shape.gather(2, indices[..., None].repeat(1, 1, 1, 3))  # Shape: (B, F, 2, 3)
        return points[:, :, 0, :], points[:, :, 1, :]

    if body_bones_type in ('z_minmax_y+', 'mine'):
        ## TODO: mean may not be very robust, as inside is noisy
        mid_point = seq_shape.mean(2)
        # Only vertices whose y is above (mean y - 0.5) are eligible.
        upper_mask = (seq_shape[:, :, :, 1] > (mid_point[:, :, None, 1] - 0.5)).float()
        zs = seq_shape[:, :, :, 2]
        # Masked-out vertices are pushed to -/+1e6 so they never win argmax/argmin.
        idx_a = (zs * upper_mask + (-1e6) * (1 - upper_mask)).argmax(2)
        point_a = seq_shape.gather(2, idx_a[..., None, None].repeat(1, 1, 1, 3)).squeeze(2)
        idx_b = (zs * upper_mask + 1e6 * (1 - upper_mask)).argmin(2)
        point_b = seq_shape.gather(2, idx_b[..., None, None].repeat(1, 1, 1, 3)).squeeze(2)
        return point_a, point_b

    # 'max_distance' (the two mutually farthest vertices in the y-z plane, ordered
    # along z) was planned but never implemented; unknown types end up here too.
    raise NotImplementedError


def _leg_quadrant_masks(seq_shape, bone_y_threshold):
    """
    Return four boolean vertex masks, one per leg quadrant.

    top down view (y is up, the y-z plane is the symmetry plane):

               |
            2  |  1
        -------|------ > x
            3  |  0
               v
               z
    """
    xs, ys, zs = seq_shape.unbind(-1)
    if bone_y_threshold is None:
        x_margin = (xs.quantile(0.95) - xs.quantile(0.05)) * 0.2
        x0 = 0
        quadrant0 = torch.logical_and(xs - x0 > x_margin, zs > 0)
        quadrant1 = torch.logical_and(xs - x0 > x_margin, zs < 0)
        quadrant2 = torch.logical_and(xs - x0 < -x_margin, zs < 0)
        quadrant3 = torch.logical_and(xs - x0 < -x_margin, zs > 0)
    else:
        # Centre and margins are estimated from the lowest vertices only
        # (those below the ``bone_y_threshold`` quantile of y).
        y_threshold = ys.quantile(bone_y_threshold)
        flags = (ys < y_threshold)
        x0 = xs[flags].quantile(0.5)
        z0 = zs[flags].quantile(0.5)
        x_margin = (xs[flags].quantile(0.95) - xs[flags].quantile(0.05)) * 0.2
        z_margin = (zs[flags].quantile(0.95) - zs[flags].quantile(0.05)) * 0.2
        # NOTE: only the z > z0 quadrants (0 and 3) apply the z margin.
        quadrant0 = torch.logical_and(xs - x0 > x_margin, zs - z0 > z_margin)
        quadrant1 = torch.logical_and(xs - x0 > x_margin, zs < z0)
        quadrant2 = torch.logical_and(xs - x0 < -x_margin, zs < z0)
        quadrant3 = torch.logical_and(xs - x0 < -x_margin, zs - z0 > z_margin)
    return [quadrant0, quadrant1, quadrant2, quadrant3]


def _find_leg_joints(seq_shape, bones_pred, quadrant, n_bones, body_bone_idx):
    """
    Interpolate ``n_bones + 1`` leg joints from the foot to a body joint.

    The foot is the lowest-y vertex inside ``quadrant``. When ``body_bone_idx``
    is None it is resolved once, on the first (batch, frame) item, as the body
    bone whose end joint is closest to the foot along z, and then reused.

    Returns:
        (all_joints, body_bone_idx) with all_joints of shape (B, F, n_bones + 1, 3).
    Raises:
        ValueError: if a quadrant contains no vertex for some (batch, frame).
    """
    n_batch, n_frames = seq_shape.shape[0], seq_shape.shape[1]
    all_joints = torch.zeros([n_batch, n_frames, n_bones + 1, 3], dtype=seq_shape.dtype, device=seq_shape.device)
    blend = torch.linspace(0., 1., n_bones + 1, device=seq_shape.device)[:, None]
    for b in range(n_batch):
        for f in range(n_frames):
            quadrant_points = seq_shape[b, f][quadrant[b, f]]
            if quadrant_points.numel() == 0:
                raise ValueError(
                    f"No mesh vertices fall in the leg quadrant for batch {b}, frame {f}; cannot place a foot."
                )
            foot = quadrant_points[torch.argmin(quadrant_points[:, 1])]  # lowest y

            # find closest point on the body joints (the end joint of the bone)
            if body_bone_idx is None:
                body_bone_idx = int(torch.argmin((bones_pred[b, f, :, 1, 2] - foot[None, 2]).abs()))  # closest in z axis
            body_joint = bones_pred[b, f, body_bone_idx, 1]

            # create bone structure from the foot to the body joint
            all_joints[b, f] = foot[None] * (1 - blend) + body_joint[None] * blend
    return all_joints, body_bone_idx


def estimate_bones(seq_shape, n_body_bones, resample=False, n_legs=4, n_leg_bones=0, body_bones_type='z_minmax', compute_kinematic_chain=True, aux=None, attach_legs_to_body=True, bone_y_threshold=None, body_bone_idx_preset=[3, 5, 5, 3]):
    """
    Estimate the position and structure of bones given the mesh vertex positions.

    Args:
        seq_shape: a tensor of shape (B, F, V, 3), the batched position of mesh vertices.
        n_body_bones: an even integer, the number of body bones.
        resample: if True, the vertices are first reduced to V // 4 points with
            ``geometry.sample_farthest_points``.
        n_legs: must be 4 whenever ``n_leg_bones > 0``.
        n_leg_bones: number of bones per leg; 0 disables legs.
        body_bones_type: 'z_minmax', 'z_minmax_y+' or 'mine' (same end points as
            'z_minmax_y+', plus ``lift_points_mesh`` on the body joints).
            Anything else raises NotImplementedError.
        compute_kinematic_chain: if True the bone topology is built from scratch
            (any ``aux`` passed in is ignored); if False it is read from ``aux``,
            which must come from an earlier call with this flag set.
        aux: dict with 'bones_to_joints', 'kinematic_chain' and, with legs, 'legs'.
        attach_legs_to_body: forwarded to ``update_body_kinematic_chain``.
        bone_y_threshold: optional quantile of y; when given, the leg quadrants
            are centred on the vertices below that quantile.
        body_bone_idx_preset: per-leg index of the body bone the leg attaches to.
            Legs 2 and 3 always reuse the resolved values of legs 1 and 0.
            [0, 0, 0, 0] means "pick the body bone closest to the foot along z".
            The argument is never modified.
    Returns:
        If ``compute_kinematic_chain``: (bones_pred, kinematic_chain, aux) where
            bones_pred: a detached tensor of shape (B, F, num_bones, 2, 3),
                body bones first, then the bones of each leg.
            kinematic_chain: a list of (bone_idx, dependent_bone_idxs) tuples.
            aux: the topology dict described above.
        Otherwise only bones_pred.
    Raises:
        ValueError: if a leg quadrant contains no vertex.
    """
    # preprocess shape
    if resample:
        b, _, n, _ = seq_shape.shape
        seq_shape = geometry.sample_farthest_points(rearrange(seq_shape, 'b f n d -> (b f) d n'), n // 4)
        seq_shape = rearrange(seq_shape, '(b f) d n -> b f n d', b=b)

    point_a, point_b = _body_endpoints(seq_shape, body_bones_type)

    # place points on the symmetry axis
    point_a[..., 0] = 0
    point_b[..., 0] = 0

    mid_point = seq_shape.mean(2)  # Shape: (B, F, 3)
    # place points on the symmetry axis
    mid_point[..., 0] = 0
    if n_leg_bones > 0:
        mid_point[..., 1] += 0.5  # lift mid point a bit higher if there are legs

    assert n_body_bones % 2 == 0
    n_joints = n_body_bones + 1
    blend = torch.linspace(0., 1., math.ceil(n_joints / 2), device=point_a.device)[None, None, :, None]  # Shape: (1, 1, (n_joints + 1) / 2, 1)
    # point_a to mid_point
    joints_a = point_a[:, :, None, :] * (1 - blend) + mid_point[:, :, None, :] * blend
    # mid_point to point_b
    joints_b = point_b[:, :, None, :] * blend + mid_point[:, :, None, :] * (1 - blend)
    # joints_a's last entry is the mid point, which joints_b starts with
    joints = torch.cat([joints_a[:, :, :-1], joints_b], 2)  # Shape: (B, F, n_joints, 3)

    if body_bones_type == 'mine':
        joints = lift_points_mesh(joints, seq_shape)

    # build bones and kinematic chain starting from leaf bones
    if compute_kinematic_chain:
        aux = {}
        half_n_body_bones = n_body_bones // 2
        bones_to_joints = []
        kinematic_chain = []
        bone_idx = 0
        # bones from point_a to mid_point
        dependent_bones = []
        for i in range(half_n_body_bones):
            bones_to_joints.append((i + 1, i))
            kinematic_chain = [(bone_idx, dependent_bones)] + kinematic_chain  # parent is always in the front
            dependent_bones = dependent_bones + [bone_idx]  # new list: chain entries must not alias
            bone_idx += 1
        # bones from point_b to mid_point
        dependent_bones = []
        for i in range(n_body_bones - 1, half_n_body_bones - 1, -1):
            bones_to_joints.append((i, i + 1))
            kinematic_chain = [(bone_idx, dependent_bones)] + kinematic_chain  # parent is always in the front
            dependent_bones = dependent_bones + [bone_idx]
            bone_idx += 1
        aux['bones_to_joints'] = bones_to_joints
    else:
        bones_to_joints = aux['bones_to_joints']
        kinematic_chain = aux['kinematic_chain']

    bones_pred = _joints_to_bones(joints, bones_to_joints)

    if n_leg_bones > 0:
        assert n_legs == 4
        # attach four legs: one foot (lowest y) per quadrant
        quadrants = _leg_quadrant_masks(seq_shape, bone_y_threshold)

        start_bone_idx = n_body_bones
        all_leg_bones = []
        if compute_kinematic_chain:
            leg_auxs = []
            # Work on a copy: the resolved indices are written back below, and
            # neither the shared default list nor a caller's list may be altered.
            body_bone_idxs = list(body_bone_idx_preset)
            if body_bone_idxs == [0, 0, 0, 0]:
                body_bone_idxs = [None, None, None, None]
        else:
            leg_auxs = aux['legs']

        for i, quadrant in enumerate(quadrants):
            if compute_kinematic_chain:
                # legs 2 and 3 attach to the same body bones as legs 1 and 0
                if i == 2:
                    body_bone_idx = body_bone_idxs[1]
                elif i == 3:
                    body_bone_idx = body_bone_idxs[0]
                else:
                    body_bone_idx = body_bone_idxs[i]

                leg_joints, body_bone_idx = _find_leg_joints(seq_shape, bones_pred, quadrant, n_leg_bones, body_bone_idx)
                body_bone_idxs[i] = body_bone_idx

                leg_bones_to_joints, leg_kinematic_chain, leg_bone_idxs = build_kinematic_chain(n_leg_bones, start_bone_idx=start_bone_idx)
                kinematic_chain = update_body_kinematic_chain(kinematic_chain, leg_kinematic_chain, body_bone_idx, leg_bone_idxs, attach_legs_to_body=attach_legs_to_body)
                leg_i_aux = {
                    'body_bone_idx': body_bone_idx,
                    'leg_bones_to_joints': leg_bones_to_joints,
                }
                start_bone_idx += n_leg_bones
            else:
                leg_i_aux = leg_auxs[i]
                body_bone_idx = leg_i_aux['body_bone_idx']
                leg_joints, _ = _find_leg_joints(seq_shape, bones_pred, quadrant, n_leg_bones, body_bone_idx)
                leg_bones_to_joints = leg_i_aux['leg_bones_to_joints']

            all_leg_bones.append(_joints_to_bones(leg_joints, leg_bones_to_joints))
            if compute_kinematic_chain:
                leg_auxs.append(leg_i_aux)

        all_bones = torch.cat([bones_pred] + all_leg_bones, dim=2)
    else:
        all_bones = bones_pred

    if compute_kinematic_chain:
        aux['kinematic_chain'] = kinematic_chain
        if n_leg_bones > 0:
            aux['legs'] = leg_auxs
        return all_bones.detach(), kinematic_chain, aux
    else:
        return all_bones.detach()
def _estimate_bone_rotation(forward):
    """
    Build a rotation matrix whose third column is the bone direction.

    (0, 0, 1) = matmul(b, R^(-1))
    assumes y, z is a symmetry plane

    Args:
        forward: tensor of shape (..., 3), the (un-normalised) bone vectors.
    Returns:
        R of shape (..., 3, 3) with columns (right, up, forward), in the dtype
        and on the device of ``forward``.

    NOTE: a bone parallel to the x axis has no well-defined ``up`` here; its
    cross product with the x axis is zero, so the result is degenerate.
    """
    forward = nn.functional.normalize(forward, p=2, dim=-1)
    # World x axis as the initial "right" guess, matching the input's dtype and
    # device so that non-float32 inputs work too.
    right = forward.new_tensor([[1, 0, 0]]).expand_as(forward)
    up = torch.cross(forward, right, dim=-1)
    up = nn.functional.normalize(up, p=2, dim=-1)
    # Re-derive right so that the three axes are mutually orthogonal.
    right = torch.cross(up, forward, dim=-1)
    right = nn.functional.normalize(right, p=2, dim=-1)
    R = torch.stack([right, up, forward], dim=-1)
    return R
def children_to_parents(kinematic_tree):
    """
    Invert a kinematic tree.

    Converts [(bone1, [children1, ...]), (bone2, [children2, ...]), ...] into
    [(bone1, [parent1, ...]), (bone2, [parent2, ...]), ...]. The bones keep
    their input order, and each parent list follows the order in which the
    parents appear in ``kinematic_tree``.
    """
    def _owners_of(target):
        # every bone that lists ``target`` among its children
        return [owner for owner, dependents in kinematic_tree if target in dependents]

    return [(bone, _owners_of(bone)) for bone, _ in kinematic_tree]
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
    """ [Borrowed from PyTorch3D]
    Return the rotation matrices for one of the rotations about an axis
    of which Euler angles describe, for each value of the angle given.
    Args:
        axis: Axis label "X", "Y" or "Z".
        angle: any shape tensor of Euler angles in radians
    Returns:
        Rotation matrices as tensor of shape (..., 3, 3).
    Raises:
        ValueError: if ``axis`` is not one of the three labels.
    """
    c = torch.cos(angle)
    s = torch.sin(angle)
    neg_s = -s
    one = torch.ones_like(angle)
    zero = torch.zeros_like(angle)

    # Row-major 3x3 entries for each axis.
    layouts = {
        "X": (one, zero, zero, zero, c, neg_s, zero, s, c),
        "Y": (c, zero, s, zero, one, zero, neg_s, zero, c),
        "Z": (c, neg_s, zero, s, c, zero, zero, zero, one),
    }
    if axis not in layouts:
        raise ValueError("letter must be either X, Y or Z.")
    return torch.stack(layouts[axis], -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """ [Borrowed from PyTorch3D]
    Convert rotations given as Euler angles in radians to rotation matrices.
    Args:
        euler_angles: Euler angles in radians as tensor of shape (..., 3).
        convention: Convention string of three uppercase letters from
            {"X", "Y", and "Z"}. The middle letter must differ from both
            of its neighbours.
    Returns:
        Rotation matrices as tensor of shape (..., 3, 3), the product of the
        three per-axis rotations in the order given by ``convention``.
    Raises:
        ValueError: on a malformed angle tensor or convention string.
    """
    if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    first, middle, last = convention
    if middle == first or middle == last:
        raise ValueError(f"Invalid convention {convention}.")
    for letter in convention:
        if letter not in ("X", "Y", "Z"):
            raise ValueError(f"Invalid letter {letter} in convention string.")

    per_axis_angles = torch.unbind(euler_angles, -1)
    rot_a, rot_b, rot_c = (
        _axis_angle_rotation(axis, per_axis_angles[k])
        for k, axis in enumerate(convention)
    )
    return torch.matmul(torch.matmul(rot_a, rot_b), rot_c)
def _prepare_transform_mtx(rotation=None, translation=None):
    """
    Assemble batched 4x4 homogeneous transforms.

    Args:
        rotation: optional tensor written into the upper-left 3x3 block;
            indexed by batch along dim 0.
        translation: optional tensor written into the first three rows of the
            last column; indexed by batch along dim 0.
    Returns:
        Tensor of shape (N, 4, 4) built from ``torch.eye(4)``, where N is the
        batch length of the given parts (1 when neither is given). It is moved
        to the device of the last part that was supplied.
    """
    def _fit_to(base, part):
        # Grow the single identity to the part's batch length and follow its device.
        if len(base) != len(part):
            assert len(base) == 1
            base = base.repeat(len(part), 1, 1)
        return base.to(part.device)

    out = torch.eye(4)[None]
    if rotation is not None:
        out = _fit_to(out, rotation)
        out[:, :3, :3] = rotation
    if translation is not None:
        out = _fit_to(out, translation)
        out[:, :3, 3] = translation
    return out
def _invert_transform_mtx(mtx):
    """
    Invert batched rigid transforms of shape (N, 4, 4).

    The upper-left 3x3 block is transposed rather than inverted, so the result
    is only the true inverse when that block is a rotation. The inverse
    translation is -R^T t.
    """
    rot_t = mtx[:, :3, :3].transpose(1, 2)
    offset = mtx[:, :3, 3].unsqueeze(-1)

    inverse = torch.eye(4)[None].repeat(len(mtx), 1, 1).to(mtx.device)
    inverse[:, :3, :3] = rot_t
    inverse[:, :3, 3] = -torch.bmm(rot_t, offset).squeeze(-1)
    return inverse
def skinning(v_pos, bones_pred, kinematic_tree, deform_params, output_posed_bones=False, temperature=1):
    """
    Deform a mesh by linear blend skinning along a kinematic tree.

    Every bone rotates the mesh about its own rest joint, composed with the
    rotations of all bones that list it as a dependent. The per-bone results
    are blended with soft vertex-to-bone weights.

    Args:
        v_pos: mesh vertices, expected as (B, F, V, 3).
        bones_pred: rest-pose bones, (B, F, num_bones, 2, 3); entry [..., 0, :]
            is the joint the bone rotates about, [..., 1, :] its other end.
        kinematic_tree: list of (bone_idx, dependent_bone_idxs) tuples.
        deform_params: per-bone Euler angles in radians ('XYZ' convention),
            indexed as deform_params[:, :, bone_idx]; its first two dims give
            (B, F).
        output_posed_bones: if True, also return the bones after posing.
        temperature: softmax temperature of the vertex-to-bone weights.
    Returns:
        (frame_shape_pred, aux) where frame_shape_pred is the deformed mesh and
        aux holds 'bones_pred', 'vertices_to_bones' and, if requested,
        'posed_bones'. For an empty ``kinematic_tree`` frame_shape_pred is the
        integer 0 (the sum of an empty list).
    """
    device = deform_params.device
    batch_size, num_frames = deform_params.shape[:2]
    shape = v_pos

    # Associate vertices to bones (weights are computed on detached vertices)
    vertices_to_bones = _compute_vertices_to_bones_weights(bones_pred, shape.detach(), temperature=temperature)  # Shape: (num_bones, B, F, V)

    rots_pred = deform_params

    if output_posed_bones:
        posed_bones = bones_pred.clone()
        if posed_bones.shape[0] != batch_size or posed_bones.shape[1] != num_frames:
            posed_bones = posed_bones.repeat(batch_size, num_frames, 1, 1, 1)  # Shape: (B, F, num_bones, 2, 3)

    # Homogeneous vertices are the same for every bone, so build them once.
    shape4 = rearrange(torch.cat([shape, torch.ones_like(shape[..., :1])], dim=-1), 'b f ... -> (b f) ...')

    # A bone's matrices depend only on the bone itself, so they are cached and
    # shared by every chain the bone takes part in.
    bone_mtx_cache = {}

    def _bone_matrices(i):
        """Return (world-to-bone, bone rotation, bone-to-world) for bone ``i``."""
        if i not in bone_mtx_cache:
            rest_joint = bones_pred[:, :, i, 0, :].reshape(-1, 3)
            rest_bone_vector = bones_pred[:, :, i, 1, :] - bones_pred[:, :, i, 0, :]
            rest_bone_rot = _estimate_bone_rotation(rest_bone_vector.reshape(-1, 3))
            rest_bone_mtx = _prepare_transform_mtx(rotation=rest_bone_rot, translation=rest_joint)
            rest_bone_inv_mtx = _invert_transform_mtx(rest_bone_mtx)
            rot_pred_mat = euler_angles_to_matrix(rots_pred[:, :, i].reshape(-1, 3), convention='XYZ')
            rot_pred_mtx = _prepare_transform_mtx(rotation=rot_pred_mat, translation=None)
            bone_mtx_cache[i] = (rest_bone_inv_mtx, rot_pred_mtx, rest_bone_mtx)
        return bone_mtx_cache[i]

    # Rotate vertices based on bone assignments
    frame_shape_pred = []
    # Go through each bone
    ## TODO: this assumes the parents is always in the front of the list
    for bone_id, parents_ids in children_to_parents(kinematic_tree):
        # Establish a kinematic chain with current bone as the leaf bone,
        # ordered from leaf to root
        chain_ids = (parents_ids + [bone_id])[::-1]

        # Go through the kinematic chain from leaf to root and compose transformation
        transform_mtx = torch.eye(4)[None].to(device)
        for i in chain_ids:
            rest_bone_inv_mtx, rot_pred_mtx, rest_bone_mtx = _bone_matrices(i)
            # Transform to the bone local frame
            transform_mtx = torch.matmul(rest_bone_inv_mtx, transform_mtx)
            # Rotate the mesh in the bone local frame
            transform_mtx = torch.matmul(rot_pred_mtx, transform_mtx)
            # Transform to the world frame
            transform_mtx = torch.matmul(rest_bone_mtx, transform_mtx)

        # Transform vertices (row vectors, hence the transposed matrix)
        seq_shape_bone = torch.matmul(shape4, transform_mtx.transpose(-2, -1))[..., :3]
        seq_shape_bone = rearrange(seq_shape_bone, '(b f) ... -> b f ...', b=batch_size, f=num_frames)

        if output_posed_bones:
            bones4 = torch.cat([rearrange(posed_bones[:, :, bone_id], 'b f ... -> (b f) ...'), torch.ones(batch_size * num_frames, 2, 1).to(device)], dim=-1)
            posed_bones[:, :, bone_id] = rearrange(torch.matmul(bones4, transform_mtx.transpose(-2, -1))[..., :3], '(b f) ... -> b f ...', b=batch_size, f=num_frames)

        # Transform mesh with weights
        frame_shape_pred += [vertices_to_bones[bone_id, ..., None] * seq_shape_bone]

    frame_shape_pred = sum(frame_shape_pred)

    aux = {}
    aux['bones_pred'] = bones_pred
    aux['vertices_to_bones'] = vertices_to_bones
    if output_posed_bones:
        aux['posed_bones'] = posed_bones

    return frame_shape_pred, aux