Spaces:
Sleeping
Sleeping
# import pytorch3d | |
import torch | |
from einops import rearrange | |
from torch._C import device | |
def edges_to_sparse_incidence(edges, num_vertices): | |
num_edges = edges.shape[0] | |
row_indexes = torch.arange(num_edges, dtype=torch.long, device=edges.device).repeat_interleave(2) | |
col_indexes = edges.reshape(-1) | |
indexes = torch.stack([row_indexes, col_indexes]) | |
values = torch.FloatTensor([1, -1]).to(edges.device).repeat(num_edges) | |
return torch.sparse.FloatTensor(indexes, values, torch.Size([num_edges, num_vertices])) | |
def compute_svd_rotation(vertices_rest_pose, vertices_deformed_pose, incidence_mat): | |
""" | |
Adapted from: | |
https://github.com/kzhou23/shape_pose_disent/blob/a8017c405892c98f52fa9775327172633290b1d8/arap.py#L76 | |
vertices_rest_pose: B x V x D | |
vertices_deformed_pose: B x V x D | |
incidence_mat: E x V | |
""" | |
batch_size, num_vertices, dimensions = vertices_rest_pose.shape | |
vertices = torch.cat((vertices_rest_pose, vertices_deformed_pose), dim=0) | |
# 2B x V x D -> V x (D x 2B) | |
vertices = rearrange(vertices, 'a v d -> v (d a)') | |
# E x V . V x (D x 2B) - > E x (D x 2B) | |
edges = torch.sparse.mm(incidence_mat, vertices) | |
edges = rearrange(edges, 'e (d a) -> a e d', d=dimensions) | |
rest_edges, deformed_edges = torch.split(edges, batch_size, dim=0) | |
edges_outer = torch.matmul(rest_edges[:, :, :, None], deformed_edges[:, :, None, :]) | |
edges_outer = rearrange(edges_outer, 'b e d1 d2 -> e (b d1 d2)') | |
abs_incidence_mat = incidence_mat.clone() | |
abs_incidence_mat._values()[:] = torch.abs(abs_incidence_mat._values()) | |
# transposed S | |
S = torch.sparse.mm(abs_incidence_mat.t(), edges_outer) | |
S = rearrange(S, 'v (b d1 d2) -> b v d2 d1', v=num_vertices, b=batch_size, d1=dimensions, d2=dimensions) | |
# SVD on gpu is extremely slow! https://github.com/pytorch/pytorch/pull/48436 | |
device = S.device | |
U, _, V = torch.svd(S.cpu()) | |
U = U.to(device) | |
V = V.to(device) | |
det_sign = torch.det(torch.matmul(U, V.transpose(-2, -1))) | |
U = torch.cat([U[..., :-1], U[..., -1:] * det_sign[..., None, None]], axis=-1) | |
rotations = torch.matmul(U, V.transpose(-2, -1)) | |
return rotations | |
def compute_rotation(vertices_rest_pose, vertices_deformed_pose, edges): | |
""" | |
vertices_rest_pose: B x V x D | |
vertices_deformed_pose: B x V x D | |
edges: E x 2 | |
""" | |
num_vertices = vertices_rest_pose.shape[1] | |
incidence_mat = edges_to_sparse_incidence(edges, num_vertices) | |
rot = compute_svd_rotation(vertices_rest_pose, vertices_deformed_pose, incidence_mat) | |
rot = pytorch3d.transforms.matrix_to_quaternion(rot) | |
return rot | |
def quaternion_normalize(quaternion, eps=1e-12): | |
""" | |
Adapted from tensorflow_graphics | |
Normalizes a quaternion. | |
Note: | |
In the following, A1 to An are optional batch dimensions. | |
Args: | |
quaternion: A tensor of shape `[A1, ..., An, 4]`, where the last dimension | |
represents a quaternion. | |
eps: A lower bound value for the norm that defaults to 1e-12. | |
name: A name for this op that defaults to "quaternion_normalize". | |
Returns: | |
A N-D tensor of shape `[?, ..., ?, 1]` where the quaternion elements have | |
been normalized. | |
Raises: | |
ValueError: If the shape of `quaternion` is not supported. | |
""" | |
return l2_normalize(quaternion, dim=-1, epsilon=eps) | |
def l2_normalize(x, dim=-1, epsilon=1e-12): | |
square_sum = torch.sum(x ** 2, dim=dim, keepdim=True) | |
x_inv_norm = torch.rsqrt(torch.clamp(square_sum, min=epsilon)) | |
return x * x_inv_norm | |
def arap_energy(vertices_rest_pose, | |
vertices_deformed_pose, | |
quaternions, | |
edges, | |
vertex_weight=None, | |
edge_weight=None, | |
conformal_energy=True, | |
aggregate_loss=True): | |
""" | |
Adapted from tensorflow_graphics | |
Estimates an As Conformal As Possible (ACAP) fitting energy. | |
For a given mesh in rest pose, this function evaluates a variant of the ACAP | |
[1] fitting energy for a batch of deformed meshes. The vertex weights and edge | |
weights are defined on the rest pose. | |
The method implemented here is similar to [2], but with an added free variable | |
capturing a scale factor per vertex. | |
[1]: Yusuke Yoshiyasu, Wan-Chun Ma, Eiichi Yoshida, and Fumio Kanehiro. | |
"As-Conformal-As-Possible Surface Registration." Computer Graphics Forum. Vol. | |
33. No. 5. 2014.</br> | |
[2]: Olga Sorkine, and Marc Alexa. | |
"As-rigid-as-possible surface modeling". Symposium on Geometry Processing. | |
Vol. 4. 2007. | |
Note: | |
In the description of the arguments, V corresponds to | |
the number of vertices in the mesh, and E to the number of edges in this | |
mesh. | |
Note: | |
In the following, A1 to An are optional batch dimensions. | |
Args: | |
vertices_rest_pose: A tensor of shape `[V, 3]` containing the position of | |
all the vertices of the mesh in rest pose. | |
vertices_deformed_pose: A tensor of shape `[A1, ..., An, V, 3]` containing | |
the position of all the vertices of the mesh in deformed pose. | |
quaternions: A tensor of shape `[A1, ..., An, V, 4]` defining a rigid | |
transformation to apply to each vertex of the rest pose. See Section 2 | |
from [1] for further details. | |
edges: A tensor of shape `[E, 2]` defining indices of vertices that are | |
connected by an edge. | |
vertex_weight: An optional tensor of shape `[V]` defining the weight | |
associated with each vertex. Defaults to a tensor of ones. | |
edge_weight: A tensor of shape `[E]` defining the weight of edges. Common | |
choices for these weights include uniform weighting, and cotangent | |
weights. Defaults to a tensor of ones. | |
conformal_energy: A `bool` indicating whether each vertex is associated with | |
a scale factor or not. If this parameter is True, scaling information must | |
be encoded in the norm of `quaternions`. If this parameter is False, this | |
function implements the energy described in [2]. | |
aggregate_loss: A `bool` defining whether the returned loss should be an | |
aggregate measure. When True, the mean squared error is returned. When | |
False, returns two losses for every edge of the mesh. | |
name: A name for this op. Defaults to "as_conformal_as_possible_energy". | |
Returns: | |
When aggregate_loss is `True`, returns a tensor of shape `[A1, ..., An]` | |
containing the ACAP energies. When aggregate_loss is `False`, returns a | |
tensor of shape `[A1, ..., An, 2*E]` containing each term of the summation | |
described in the equation 7 of [2]. | |
Raises: | |
ValueError: if the shape of `vertices_rest_pose`, `vertices_deformed_pose`, | |
`quaternions`, `edges`, `vertex_weight`, or `edge_weight` is not supported. | |
""" | |
# with tf.compat.v1.name_scope(name, "as_conformal_as_possible_energy", [ | |
# vertices_rest_pose, vertices_deformed_pose, quaternions, edges, | |
# conformal_energy, vertex_weight, edge_weight | |
# ]): | |
# vertices_rest_pose = tf.convert_to_tensor(value=vertices_rest_pose) | |
# vertices_deformed_pose = tf.convert_to_tensor(value=vertices_deformed_pose) | |
# quaternions = tf.convert_to_tensor(value=quaternions) | |
# edges = tf.convert_to_tensor(value=edges) | |
# if vertex_weight is not None: | |
# vertex_weight = tf.convert_to_tensor(value=vertex_weight) | |
# if edge_weight is not None: | |
# edge_weight = tf.convert_to_tensor(value=edge_weight) | |
# shape.check_static( | |
# tensor=vertices_rest_pose, | |
# tensor_name="vertices_rest_pose", | |
# has_rank=2, | |
# has_dim_equals=(-1, 3)) | |
# shape.check_static( | |
# tensor=vertices_deformed_pose, | |
# tensor_name="vertices_deformed_pose", | |
# has_rank_greater_than=1, | |
# has_dim_equals=(-1, 3)) | |
# shape.check_static( | |
# tensor=quaternions, | |
# tensor_name="quaternions", | |
# has_rank_greater_than=1, | |
# has_dim_equals=(-1, 4)) | |
# shape.compare_batch_dimensions( | |
# tensors=(vertices_deformed_pose, quaternions), | |
# last_axes=(-3, -3), | |
# broadcast_compatible=False) | |
# shape.check_static( | |
# tensor=edges, tensor_name="edges", has_rank=2, has_dim_equals=(-1, 2)) | |
# tensors_with_vertices = [vertices_rest_pose, | |
# vertices_deformed_pose, | |
# quaternions] | |
# names_with_vertices = ["vertices_rest_pose", | |
# "vertices_deformed_pose", | |
# "quaternions"] | |
# axes_with_vertices = [-2, -2, -2] | |
# if vertex_weight is not None: | |
# shape.check_static( | |
# tensor=vertex_weight, tensor_name="vertex_weight", has_rank=1) | |
# tensors_with_vertices.append(vertex_weight) | |
# names_with_vertices.append("vertex_weight") | |
# axes_with_vertices.append(0) | |
# shape.compare_dimensions( | |
# tensors=tensors_with_vertices, | |
# axes=axes_with_vertices, | |
# tensor_names=names_with_vertices) | |
# if edge_weight is not None: | |
# shape.check_static( | |
# tensor=edge_weight, tensor_name="edge_weight", has_rank=1) | |
# shape.compare_dimensions( | |
# tensors=(edges, edge_weight), | |
# axes=(0, 0), | |
# tensor_names=("edges", "edge_weight")) | |
if not conformal_energy: | |
quaternions = quaternion_normalize(quaternions) | |
# Extracts the indices of vertices. | |
indices_i, indices_j = torch.unbind(edges, dim=-1) | |
# Extracts the vertices we need per term. | |
vertices_i_rest = vertices_rest_pose[..., indices_i, :] | |
vertices_j_rest = vertices_rest_pose[..., indices_j, :] | |
vertices_i_deformed = vertices_deformed_pose[..., indices_i, :] | |
vertices_j_deformed = vertices_deformed_pose[..., indices_j, :] | |
# Extracts the weights we need per term. | |
weights_shape = vertices_i_rest.shape[-2] | |
if vertex_weight is not None: | |
weight_i = vertex_weight[indices_i] | |
weight_j = vertex_weight[indices_j] | |
else: | |
weight_i = weight_j = torch.ones(weights_shape, dtype=vertices_rest_pose.dtype, device=vertices_rest_pose.device) | |
weight_i = weight_i[..., None] | |
weight_j = weight_j[..., None] | |
if edge_weight is not None: | |
weight_ij = edge_weight | |
else: | |
weight_ij = torch.ones(weights_shape, dtype=vertices_rest_pose.dtype, device=vertices_rest_pose.device) | |
weight_ij = weight_ij[..., None] | |
# Extracts the rotation we need per term. | |
quaternion_i = quaternions[..., indices_i, :] | |
quaternion_j = quaternions[..., indices_j, :] | |
# Computes the energy. | |
deformed_ij = vertices_i_deformed - vertices_j_deformed | |
rotated_rest_ij = pytorch3d.transforms.quaternion_apply(quaternion_i, (vertices_i_rest - vertices_j_rest)) | |
energy_ij = weight_i * weight_ij * (deformed_ij - rotated_rest_ij) | |
deformed_ji = vertices_j_deformed - vertices_i_deformed | |
rotated_rest_ji = pytorch3d.transforms.quaternion_apply(quaternion_j, (vertices_j_rest - vertices_i_rest)) | |
energy_ji = weight_j * weight_ij * (deformed_ji - rotated_rest_ji) | |
energy_ij_squared = torch.sum(energy_ij ** 2, dim=-1) | |
energy_ji_squared = torch.sum(energy_ji ** 2, dim=-1) | |
if aggregate_loss: | |
average_energy_ij = torch.mean(energy_ij_squared, dim=-1) | |
average_energy_ji = torch.mean(energy_ji_squared, dim=-1) | |
return (average_energy_ij + average_energy_ji) / 2.0 | |
return torch.cat((energy_ij_squared, energy_ji_squared), dim=-1) | |
def arap_loss(vertices_rest_pose, vertices_deformed_pose, edges): | |
# squash batch dimensions | |
vertices_rest_pose_shape = list(vertices_rest_pose.shape) | |
vertices_deformed_pose_shape = list(vertices_deformed_pose.shape) | |
vertices_rest_pose = vertices_rest_pose.reshape([-1] + vertices_rest_pose_shape[-2:]) | |
vertices_deformed_pose = vertices_deformed_pose.reshape([-1] + vertices_deformed_pose_shape[-2:]) | |
# try: | |
quaternions = compute_rotation(vertices_rest_pose, vertices_deformed_pose, edges) | |
# except RuntimeError: | |
# print('SVD did not converge') | |
# batch_size = vertices_rest_pose.shape[0] | |
# num_vertices = vertices_rest_pose.shape[-2] | |
# quaternions = pytorch3d.transforms.matrix_to_quaternion(pytorch3d.transforms.euler_angles_to_matrix(torch.zeros([batch_size, num_vertices, 3], device=vertices_rest_pose.device), 'XYZ')) | |
quaternions = quaternions.detach() | |
energy = arap_energy( | |
vertices_rest_pose, | |
vertices_deformed_pose, | |
quaternions, | |
edges, | |
aggregate_loss=True, | |
conformal_energy=False) | |
return energy.reshape(vertices_rest_pose_shape[:-2]) | |