# 3DFauna_demo / video3d / utils / skinning_v4.py
# Committed by kyleleey in 98a77e0 ("first commit"); original file size 23.7 kB.
import math
import torch
import torch.nn as nn
from . import geometry
from einops import rearrange
import itertools
def _joints_to_bones(joints, bones_idxs):
bones = []
for a, b in bones_idxs:
bones += [torch.stack([joints[:, :, a, :], joints[:, :, b, :]], dim=2)]
bones = torch.stack(bones, dim=2)
return bones
def _compute_vertices_to_bones_weights(bones_pred, seq_shape_pred, temperature=1):
    """
    Soft-assign every vertex to the bones with a softmax over negated vertex-to-bone distances.

    Args:
        bones_pred: bones indexed as (B, F, num_bones, 2, 3); bone k is the segment between
            bones_pred[:, :, k, 0] and bones_pred[:, :, k, 1].
        seq_shape_pred: vertex positions passed straight to ``geometry.line_segment_distance``
            (presumably (B, F, V, 3) -- TODO confirm).
        temperature: softmax temperature; smaller values give harder assignments.

    Returns:
        Weights stacked along a new leading bone axis and normalized over it
        (the skinning caller indexes them as (num_bones, B, F, V)).
    """
    segment_distances = torch.stack([
        geometry.line_segment_distance(bones_pred[:, :, k, 0], bones_pred[:, :, k, 1], seq_shape_pred)
        for k in range(bones_pred.shape[2])
    ])
    # Negate so that closer bones receive larger weights.
    return nn.functional.softmax(-segment_distances / temperature, dim=0)
def build_kinematic_chain(n_bones, start_bone_idx):
    """
    Build a single linear chain of ``n_bones`` bones numbered from ``start_bone_idx``.

    Bone ``start_bone_idx + k`` connects local joint ``k + 1`` to local joint ``k``. The first
    bone is the leaf; every later bone has all earlier bones of the chain as dependents.

    Returns:
        (bones_to_joints, kinematic_chain, bone_idxs) where
        bones_to_joints: list of (k + 1, k) joint pairs, one per bone;
        kinematic_chain: list of (bone_idx, dependent_bone_idxs) ordered from the last bone to the
            first, so a parent always precedes its dependents; each dependent list is a distinct
            object (callers extend them in place);
        bone_idxs: a fresh list of all bone indices in the chain.
    """
    bone_idxs = list(range(start_bone_idx, start_bone_idx + n_bones))
    bones_to_joints = [(k + 1, k) for k in range(n_bones)]
    # Slicing creates a new list per entry, so no two entries share a dependents list.
    kinematic_chain = [(bone_idxs[k], bone_idxs[:k]) for k in reversed(range(n_bones))]
    return bones_to_joints, kinematic_chain, list(bone_idxs)
def update_body_kinematic_chain(kinematic_chain, leg_kinematic_chain, body_bone_idx, leg_bone_idxs, attach_legs_to_body=True):
    """
    Append a leg chain to the body chain, optionally making the leg depend on the body.

    When ``attach_legs_to_body`` is true, ``leg_bone_idxs`` are added as dependents of the body
    bone ``body_bone_idx`` and of every bone that already lists it as a dependent (its ancestors).
    NOTE: the dependent lists inside ``kinematic_chain`` are extended in place.

    Returns:
        A new list: the body entries followed by ``leg_kinematic_chain`` (parents stay in front).
    """
    if attach_legs_to_body:
        for bone_idx, dependent_bones in kinematic_chain:
            affected = bone_idx == body_bone_idx or body_bone_idx in dependent_bones
            if affected:
                dependent_bones.extend(leg_bone_idxs)
    return kinematic_chain + leg_kinematic_chain
def lift_points_mesh(points, seq_shape, size_aspect=0.5):
    """
    Snap the y coordinate of interior joints onto the mesh.

    For every joint except the first and last, the z-window reaching ``size_aspect`` of the way
    towards the previous joint's z (upper bound) and the next joint's z (lower bound) is formed;
    the joint's y becomes the largest y among vertices whose z lies strictly inside that window.
    Joints with no vertex in their window keep their y. The window is only non-empty when z
    decreases along the joint sequence.

    Args:
        points: joints indexed as (B, F, J, 3); modified in place.
        seq_shape: mesh vertices indexed as (B, F, V, 3).
        size_aspect: fraction of the distance to each neighbouring joint covered by the window.

    Returns:
        ``points`` (the same tensor, with interior y values updated).
    """
    sentinel = -1e6
    interior_z = points[:, :, 1:-1, :][..., 2]
    n_interior = interior_z.shape[-1]
    n_verts = seq_shape.shape[-2]

    upper = interior_z - size_aspect * (interior_z - points[:, :, :-2, 2])
    lower = interior_z - size_aspect * (interior_z - points[:, :, 2:, 2])
    upper = upper.unsqueeze(-1).expand(-1, -1, -1, n_verts)
    lower = lower.unsqueeze(-1).expand(-1, -1, -1, n_verts)

    # One copy of the vertex set per interior joint: (B, F, J - 2, V, 3).
    verts = seq_shape.unsqueeze(2).expand(-1, -1, n_interior, -1, -1)
    verts_z = verts[..., 2]
    in_window = ((verts_z > lower) * (verts_z < upper)).float()

    # Vertices outside the window are pushed to the sentinel so they never win the max.
    masked_y = verts[..., 1] * in_window + sentinel * (1 - in_window)
    best_y = masked_y.max(dim=-1).values
    found = (best_y != sentinel).float()

    points[:, :, 1:-1, 1] = points[:, :, 1:-1, 1] * (1 - found) + best_y * found
    return points
def _select_body_endpoints(seq_shape, body_bones_type):
    """
    Pick the two end points of the body bone chain for every (batch, frame).

    'z_minmax' takes the vertices with the largest and smallest z. 'z_minmax_y+' and 'mine'
    do the same but only among vertices whose y exceeds (mean vertex y - 0.5).

    Returns:
        (point_a, point_b), each (B, F, 3): point_a at the largest z, point_b at the smallest z.
    """
    if body_bones_type == 'z_minmax':
        z = seq_shape[..., 2]
        indices = torch.cat([z.argmax(dim=2)[..., None], z.argmin(dim=2)[..., None]], dim=2)
        points = seq_shape.gather(2, indices[..., None].repeat(1, 1, 1, 3))  # (B, F, 2, 3)
        return points[:, :, 0, :], points[:, :, 1, :]
    if body_bones_type in ('z_minmax_y+', 'mine'):
        # TODO: the mean may not be very robust, as the mesh interior is noisy
        mean_point = seq_shape.mean(2)
        upper_mask = (seq_shape[:, :, :, 1] > (mean_point[:, :, None, 1] - 0.5)).float()
        z = seq_shape[:, :, :, 2]
        # push excluded vertices far away so they can never be selected
        z_for_max = z * upper_mask + (-1e6) * (1 - upper_mask)
        z_for_min = z * upper_mask + 1e6 * (1 - upper_mask)
        point_a = seq_shape.gather(2, z_for_max.argmax(2)[..., None, None].repeat(1, 1, 1, 3)).squeeze(2)
        point_b = seq_shape.gather(2, z_for_min.argmin(2)[..., None, None].repeat(1, 1, 1, 3)).squeeze(2)
        return point_a, point_b
    raise NotImplementedError(f"body_bones_type '{body_bones_type}' is not supported")


def _leg_quadrant_masks(seq_shape, bone_y_threshold):
    """
    Split the vertices into the four leg quadrants of the x/z plane.

    Top-down view (y is up, the y-z plane is the symmetry plane):

               |
           2   |   1
        -------|------ > x
           3   |   0
               v
               z

    Without ``bone_y_threshold`` the split is centred on x = 0, z = 0 with an x margin of 20% of
    the 5%-95% x spread. With it, centre and margins come from the vertices below that y quantile,
    and the z > centre quadrants (0 and 3) additionally require a z margin.

    Returns:
        List of four boolean masks of shape (B, F, V), ordered as quadrants 0..3.
    """
    xs, ys, zs = seq_shape.unbind(-1)
    if bone_y_threshold is None:
        x_margin = (xs.quantile(0.95) - xs.quantile(0.05)) * 0.2
        return [
            torch.logical_and(xs > x_margin, zs > 0),
            torch.logical_and(xs > x_margin, zs < 0),
            torch.logical_and(xs < -x_margin, zs < 0),
            torch.logical_and(xs < -x_margin, zs > 0),
        ]
    low = ys < ys.quantile(bone_y_threshold)
    low_xs, low_zs = xs[low], zs[low]
    x0 = low_xs.quantile(0.5)
    z0 = low_zs.quantile(0.5)
    x_margin = (low_xs.quantile(0.95) - low_xs.quantile(0.05)) * 0.2
    z_margin = (low_zs.quantile(0.95) - low_zs.quantile(0.05)) * 0.2
    return [
        torch.logical_and(xs - x0 > x_margin, zs - z0 > z_margin),
        torch.logical_and(xs - x0 > x_margin, zs < z0),
        torch.logical_and(xs - x0 < -x_margin, zs < z0),
        torch.logical_and(xs - x0 < -x_margin, zs - z0 > z_margin),
    ]


def _find_leg_in_quadrant(seq_shape, bones_pred, quadrant, n_bones, body_bone_idx):
    """
    Place a straight leg from the lowest vertex of a quadrant up to a body joint.

    For every (batch, frame) the foot is the quadrant vertex with the smallest y. The leg joints
    are evenly spaced from the foot to the end joint of body bone ``body_bone_idx``. When
    ``body_bone_idx`` is None it is chosen once, on the first (batch, frame), as the body bone
    whose end joint is closest in z to the foot, and reused for all later frames.

    Returns:
        (joints, body_bone_idx) with joints of shape (B, F, n_bones + 1, 3).

    Raises:
        RuntimeError: if the quadrant contains no vertex for some (batch, frame).
    """
    n_batch, n_frames = seq_shape.shape[0], seq_shape.shape[1]
    all_joints = torch.zeros([n_batch, n_frames, n_bones + 1, 3], dtype=seq_shape.dtype, device=seq_shape.device)
    blend = torch.linspace(0., 1., n_bones + 1, device=seq_shape.device)[:, None]
    for b in range(n_batch):
        for f in range(n_frames):
            quadrant_points = seq_shape[b, f][quadrant[b, f]]
            if quadrant_points.numel() == 0:
                raise RuntimeError(f"no mesh vertices found in leg quadrant for batch {b}, frame {f}")
            foot = quadrant_points[torch.argmin(quadrant_points[:, 1])]  # lowest y
            if body_bone_idx is None:
                # closest end joint along the z axis
                body_bone_idx = int(torch.argmin((bones_pred[b, f, :, 1, 2] - foot[None, 2]).abs()))
            body_joint = bones_pred[b, f, body_bone_idx, 1]
            all_joints[b, f] = foot[None] * (1 - blend) + body_joint[None] * blend
    return all_joints, body_bone_idx


@torch.no_grad()
def estimate_bones(seq_shape, n_body_bones, resample=False, n_legs=4, n_leg_bones=0, body_bones_type='z_minmax', compute_kinematic_chain=True, aux=None, attach_legs_to_body=True, bone_y_threshold=None, body_bone_idx_preset=None):
    """
    Estimate the position and structure of bones given the mesh vertex positions.

    The body is a chain of ``n_body_bones`` bones running from ``point_a`` (largest z) through the
    vertex mean (raised by 0.5 in y when legs are requested) to ``point_b`` (smallest z); these three
    points are projected onto x = 0 and the joints are linearly interpolated between them. Optional
    legs run from the lowest vertex of each x/z quadrant up to the end joint of a body bone.

    Args:
        seq_shape: (B, F, V, 3) batched mesh vertex positions.
        n_body_bones: even number of body bones.
        resample: first subsample V // 4 vertices with ``geometry.sample_farthest_points``.
        n_legs: number of legs; only 4 is supported when ``n_leg_bones > 0``.
        n_leg_bones: bones per leg; 0 disables legs.
        body_bones_type: 'z_minmax', 'z_minmax_y+' or 'mine' (as 'z_minmax_y+', plus interior
            joints lifted onto the mesh by ``lift_points_mesh``). 'max_distance' is not implemented.
        compute_kinematic_chain: build the bone topology; otherwise reuse the one stored in ``aux``.
        aux: dict returned by a previous call with ``compute_kinematic_chain=True``; required
            when ``compute_kinematic_chain`` is False.
        attach_legs_to_body: make leg bones dependents of the body bone they attach to and of its
            ancestors.
        bone_y_threshold: optional y quantile used to centre the leg quadrants on low vertices.
        body_bone_idx_preset: four body bone indices the legs attach to (default [3, 5, 5, 3]).
            Legs 2 and 3 reuse the indices of legs 1 and 0. [0, 0, 0, 0] selects the bones
            automatically (closest end joint in z). The list is never modified.

    Returns:
        If ``compute_kinematic_chain``: (bones, kinematic_chain, aux), otherwise only bones.
        bones: (B, F, n_body_bones + n_legs * n_leg_bones, 2, 3), detached.
        kinematic_chain: list of (bone_idx, dependent_bone_idxs), parents before dependents.
        aux: dict with 'bones_to_joints', 'kinematic_chain' and, with legs, 'legs'.
    """
    # preprocess shape
    if resample:
        b, _, n, _ = seq_shape.shape
        seq_shape = geometry.sample_farthest_points(rearrange(seq_shape, 'b f n d -> (b f) d n'), n // 4)
        seq_shape = rearrange(seq_shape, '(b f) d n -> b f n d', b=b)

    point_a, point_b = _select_body_endpoints(seq_shape, body_bones_type)
    # place the end points on the symmetry plane (x = 0)
    point_a[..., 0] = 0
    point_b[..., 0] = 0

    mid_point = seq_shape.mean(2)  # (B, F, 3)
    mid_point[..., 0] = 0
    if n_leg_bones > 0:
        mid_point[..., 1] += 0.5  # lift the mid point a bit higher if there are legs

    assert n_body_bones % 2 == 0
    n_joints = n_body_bones + 1
    blend = torch.linspace(0., 1., math.ceil(n_joints / 2), device=point_a.device)[None, None, :, None]  # (1, 1, (n_joints + 1) / 2, 1)
    joints_a = point_a[:, :, None, :] * (1 - blend) + mid_point[:, :, None, :] * blend  # point_a -> mid_point
    joints_b = point_b[:, :, None, :] * blend + mid_point[:, :, None, :] * (1 - blend)  # mid_point -> point_b
    joints = torch.cat([joints_a[:, :, :-1], joints_b], 2)  # (B, F, n_joints, 3)
    if body_bones_type == 'mine':
        joints = lift_points_mesh(joints, seq_shape)

    # build bones and kinematic chain starting from the leaf bones at both ends
    if compute_kinematic_chain:
        aux = {}
        half = n_body_bones // 2
        # bones 0..half-1 go from point_a (joint 0) towards the mid joint
        front_bones_to_joints, front_chain, _ = build_kinematic_chain(half, start_bone_idx=0)
        # bones half..n_body_bones-1 go from point_b (last joint) towards the mid joint
        _, back_chain, _ = build_kinematic_chain(half, start_bone_idx=half)
        back_bones_to_joints = [(i, i + 1) for i in range(n_body_bones - 1, half - 1, -1)]
        bones_to_joints = front_bones_to_joints + back_bones_to_joints
        kinematic_chain = back_chain + front_chain  # parent is always in the front
        aux['bones_to_joints'] = bones_to_joints
    else:
        bones_to_joints = aux['bones_to_joints']
        kinematic_chain = aux['kinematic_chain']
    bones_pred = _joints_to_bones(joints, bones_to_joints)

    if n_leg_bones > 0:
        assert n_legs == 4
        quadrants = _leg_quadrant_masks(seq_shape, bone_y_threshold)

        if body_bone_idx_preset is None:
            body_bone_idx_preset = [3, 5, 5, 3]
        if list(body_bone_idx_preset) == [0, 0, 0, 0]:
            body_bone_idxs = [None, None, None, None]
        else:
            body_bone_idxs = list(body_bone_idx_preset)  # copy: never write into the caller's list

        # legs 2 and 3 mirror legs 1 and 0 across x = 0 and attach to the same body bone
        mirror_of = {2: 1, 3: 0}
        start_bone_idx = n_body_bones
        all_leg_bones = []
        leg_auxs = [] if compute_kinematic_chain else aux['legs']
        for i, quadrant in enumerate(quadrants):
            if compute_kinematic_chain:
                body_bone_idx = body_bone_idxs[mirror_of.get(i, i)]
                leg_joints, body_bone_idx = _find_leg_in_quadrant(seq_shape, bones_pred, quadrant, n_leg_bones, body_bone_idx)
                body_bone_idxs[i] = body_bone_idx
                leg_bones_to_joints, leg_kinematic_chain, leg_bone_idxs = build_kinematic_chain(n_leg_bones, start_bone_idx=start_bone_idx)
                kinematic_chain = update_body_kinematic_chain(kinematic_chain, leg_kinematic_chain, body_bone_idx, leg_bone_idxs, attach_legs_to_body=attach_legs_to_body)
                leg_auxs.append({'body_bone_idx': body_bone_idx, 'leg_bones_to_joints': leg_bones_to_joints})
                start_bone_idx += n_leg_bones
            else:
                leg_i_aux = leg_auxs[i]
                leg_joints, _ = _find_leg_in_quadrant(seq_shape, bones_pred, quadrant, n_leg_bones, leg_i_aux['body_bone_idx'])
                leg_bones_to_joints = leg_i_aux['leg_bones_to_joints']
            all_leg_bones.append(_joints_to_bones(leg_joints, leg_bones_to_joints))
        all_bones = torch.cat([bones_pred] + all_leg_bones, dim=2)
    else:
        all_bones = bones_pred

    if compute_kinematic_chain:
        aux['kinematic_chain'] = kinematic_chain
        if n_leg_bones > 0:
            aux['legs'] = leg_auxs
        return all_bones.detach(), kinematic_chain, aux
    return all_bones.detach()
def _estimate_bone_rotation(forward):
"""
(0, 0, 1) = matmul(b, R^(-1))
assumes y, z is a symmetry plane
returns R
"""
forward = nn.functional.normalize(forward, p=2, dim=-1)
right = torch.FloatTensor([[1, 0, 0]]).to(forward.device)
right = right.expand_as(forward)
up = torch.cross(forward, right, dim=-1)
up = nn.functional.normalize(up, p=2, dim=-1)
right = torch.cross(up, forward, dim=-1)
up = nn.functional.normalize(up, p=2, dim=-1)
R = torch.stack([right, up, forward], dim=-1)
return R
def children_to_parents(kinematic_tree):
    """
    Invert a kinematic tree from "bone -> dependents" to "bone -> parents".

    Converts [(bone1, [child, ...]), (bone2, [child, ...]), ...] into
    [(bone1, [parent, ...]), (bone2, [parent, ...]), ...], keeping the input bone order; a bone's
    parents are all entries (in tree order) that list it as a dependent.
    """
    return [
        (bone_id, [candidate for candidate, kids in kinematic_tree if bone_id in kids])
        for bone_id, _ in kinematic_tree
    ]
def _axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor:
""" [Borrowed from PyTorch3D]
Return the rotation matrices for one of the rotations about an axis
of which Euler angles describe, for each value of the angle given.
Args:
axis: Axis label "X" or "Y or "Z".
angle: any shape tensor of Euler angles in radians
Returns:
Rotation matrices as tensor of shape (..., 3, 3).
"""
cos = torch.cos(angle)
sin = torch.sin(angle)
one = torch.ones_like(angle)
zero = torch.zeros_like(angle)
if axis == "X":
R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos)
elif axis == "Y":
R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos)
elif axis == "Z":
R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one)
else:
raise ValueError("letter must be either X, Y or Z.")
return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3))
def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str) -> torch.Tensor:
    """ [Borrowed from PyTorch3D]
    Convert Euler angles (radians) to rotation matrices.

    The matrix is the product R_first @ R_second @ R_third of the elementary rotations named by
    ``convention``.

    Args:
        euler_angles: tensor of shape (..., 3).
        convention: three letters from {"X", "Y", "Z"}; the middle letter must differ from both
            neighbours.

    Returns:
        Tensor of shape (..., 3, 3).

    Raises:
        ValueError: for a malformed angle tensor or convention string.
    """
    if not euler_angles.dim() or euler_angles.shape[-1] != 3:
        raise ValueError("Invalid input euler angles.")
    if len(convention) != 3:
        raise ValueError("Convention must have 3 letters.")
    if convention[1] in (convention[0], convention[2]):
        raise ValueError(f"Invalid convention {convention}.")
    invalid = [letter for letter in convention if letter not in ("X", "Y", "Z")]
    if invalid:
        raise ValueError(f"Invalid letter {invalid[0]} in convention string.")
    first, second, third = (
        _axis_angle_rotation(axis, angle)
        for axis, angle in zip(convention, torch.unbind(euler_angles, -1))
    )
    return first @ second @ third
def _prepare_transform_mtx(rotation=None, translation=None):
mtx = torch.eye(4)[None]
if rotation is not None:
if len(mtx) != len(rotation):
assert len(mtx) == 1
mtx = mtx.repeat(len(rotation), 1, 1)
mtx = mtx.to(rotation.device)
mtx[:, :3, :3] = rotation
if translation is not None:
if len(mtx) != len(translation):
assert len(mtx) == 1
mtx = mtx.repeat(len(translation), 1, 1)
mtx = mtx.to(translation.device)
mtx[:, :3, 3] = translation
return mtx
def _invert_transform_mtx(mtx):
inv_mtx = torch.eye(4)[None].repeat(len(mtx), 1, 1).to(mtx.device)
rotation = mtx[:, :3, :3]
translation = mtx[:, :3, 3]
inv_mtx[:, :3, :3] = rotation.transpose(1, 2)
inv_mtx[:, :3, 3] = -torch.bmm(rotation.transpose(1, 2), translation.unsqueeze(-1)).squeeze(-1)
return inv_mtx
def skinning(v_pos, bones_pred, kinematic_tree, deform_params, output_posed_bones=False, temperature=1):
    """
    Deform mesh vertices (and optionally the bones) with linear blend skinning.

    Each bone ``i`` contributes the local transform ``rest_i @ rot_i @ rest_i^-1``, where ``rest_i``
    is built from the bone's first joint and direction (``_estimate_bone_rotation``) and ``rot_i``
    is the XYZ Euler rotation ``deform_params[:, :, i]``. A bone's transform composes its own local
    transform and those of all its parents, leaf first. Every vertex is the weighted sum of its
    per-bone transformed positions, with weights from ``_compute_vertices_to_bones_weights``.

    Args:
        v_pos: vertex positions with two leading (batch, frame) axes and a trailing xyz axis,
            presumably (B, F, V, 3) or broadcastable to it (TODO confirm).
        bones_pred: rest-pose bones (B', F', num_bones, 2, 3); B' * F' must be 1 or B * F.
        kinematic_tree: list of (bone_id, dependent_bone_ids). A bone's parents are the entries
            listing it as a dependent; parents are assumed to precede their dependents.
        deform_params: per-bone Euler angles indexed as (B, F, num_bones, 3).
        output_posed_bones: also transform the bones and return them as ``aux['posed_bones']``.
        temperature: softmax temperature for the vertex-to-bone weights.

    Returns:
        (frame_shape_pred, aux): the posed vertices, and a dict with 'bones_pred',
        'vertices_to_bones' (num_bones, B, F, V) and, if requested, 'posed_bones'.
    """
    device = deform_params.device
    batch_size, num_frames = deform_params.shape[:2]
    shape = v_pos

    # Associate vertices to bones (weights are computed on detached positions)
    vertices_to_bones = _compute_vertices_to_bones_weights(bones_pred, shape.detach(), temperature=temperature)  # Shape: (num_bones, B, F, V)
    rots_pred = deform_params

    # Homogeneous vertex coordinates flattened over (batch, frame); identical for every bone.
    shape4 = rearrange(torch.cat([shape, torch.ones_like(shape[..., :1])], dim=-1), 'b f ... -> (b f) ...')

    if output_posed_bones:
        posed_bones = bones_pred.clone()
        if posed_bones.shape[0] != batch_size or posed_bones.shape[1] != num_frames:
            posed_bones = posed_bones.repeat(batch_size, num_frames, 1, 1, 1)  # Shape: (B, F, num_bones, 2, 3)
        bone_ones = torch.ones(batch_size * num_frames, 2, 1).to(device)

    # A bone appears in the chain of every one of its dependents; build its matrices only once.
    bone_mtx_cache = {}

    def bone_matrices(i):
        """Return (rest_bone_mtx, rest_bone_inv_mtx, rot_pred_mtx) for bone ``i``."""
        if i not in bone_mtx_cache:
            rest_joint = bones_pred[:, :, i, 0, :].view(-1, 3)
            rest_bone_vector = bones_pred[:, :, i, 1, :] - bones_pred[:, :, i, 0, :]
            rest_bone_rot = _estimate_bone_rotation(rest_bone_vector.view(-1, 3))
            rest_bone_mtx = _prepare_transform_mtx(rotation=rest_bone_rot, translation=rest_joint)
            rest_bone_inv_mtx = _invert_transform_mtx(rest_bone_mtx)
            rot_pred_mat = euler_angles_to_matrix(rots_pred[:, :, i].view(-1, 3), convention='XYZ')
            rot_pred_mtx = _prepare_transform_mtx(rotation=rot_pred_mat, translation=None)
            bone_mtx_cache[i] = (rest_bone_mtx, rest_bone_inv_mtx, rot_pred_mtx)
        return bone_mtx_cache[i]

    frame_shape_pred = []
    for bone_id, _ in kinematic_tree:
        # Kinematic chain with the current bone as the leaf.
        # TODO: this assumes parents always come before their dependents in kinematic_tree
        parents_ids = [parent_id for parent_id, children in kinematic_tree if bone_id in children]
        chain_ids = (parents_ids + [bone_id])[::-1]  # leaf to root

        # Compose transformations from leaf to root (same multiplication order as per-bone rebuild).
        transform_mtx = torch.eye(4)[None].to(device)
        for i in chain_ids:
            rest_bone_mtx, rest_bone_inv_mtx, rot_pred_mtx = bone_matrices(i)
            transform_mtx = torch.matmul(rest_bone_inv_mtx, transform_mtx)  # into the bone local frame
            transform_mtx = torch.matmul(rot_pred_mtx, transform_mtx)  # rotate in the bone local frame
            transform_mtx = torch.matmul(rest_bone_mtx, transform_mtx)  # back to the world frame

        seq_shape_bone = torch.matmul(shape4, transform_mtx.transpose(-2, -1))[..., :3]
        seq_shape_bone = rearrange(seq_shape_bone, '(b f) ... -> b f ...', b=batch_size, f=num_frames)

        if output_posed_bones:
            bones4 = torch.cat([rearrange(posed_bones[:, :, bone_id], 'b f ... -> (b f) ...'), bone_ones], dim=-1)
            posed_bones[:, :, bone_id] = rearrange(torch.matmul(bones4, transform_mtx.transpose(-2, -1))[..., :3], '(b f) ... -> b f ...', b=batch_size, f=num_frames)

        # Weight this bone's transformed mesh by the per-vertex bone weights
        frame_shape_pred.append(vertices_to_bones[bone_id, ..., None] * seq_shape_bone)
    frame_shape_pred = sum(frame_shape_pred)

    aux = {}
    aux['bones_pred'] = bones_pred
    aux['vertices_to_bones'] = vertices_to_bones
    if output_posed_bones:
        aux['posed_bones'] = posed_bones
    return frame_shape_pred, aux