Spaces:
Sleeping
Sleeping
File size: 12,356 Bytes
98a77e0 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 |
# import pytorch3d
import torch
from einops import rearrange
from torch._C import device
def edges_to_sparse_incidence(edges, num_vertices):
num_edges = edges.shape[0]
row_indexes = torch.arange(num_edges, dtype=torch.long, device=edges.device).repeat_interleave(2)
col_indexes = edges.reshape(-1)
indexes = torch.stack([row_indexes, col_indexes])
values = torch.FloatTensor([1, -1]).to(edges.device).repeat(num_edges)
return torch.sparse.FloatTensor(indexes, values, torch.Size([num_edges, num_vertices]))
def compute_svd_rotation(vertices_rest_pose, vertices_deformed_pose, incidence_mat):
"""
Adapted from:
https://github.com/kzhou23/shape_pose_disent/blob/a8017c405892c98f52fa9775327172633290b1d8/arap.py#L76
vertices_rest_pose: B x V x D
vertices_deformed_pose: B x V x D
incidence_mat: E x V
"""
batch_size, num_vertices, dimensions = vertices_rest_pose.shape
vertices = torch.cat((vertices_rest_pose, vertices_deformed_pose), dim=0)
# 2B x V x D -> V x (D x 2B)
vertices = rearrange(vertices, 'a v d -> v (d a)')
# E x V . V x (D x 2B) - > E x (D x 2B)
edges = torch.sparse.mm(incidence_mat, vertices)
edges = rearrange(edges, 'e (d a) -> a e d', d=dimensions)
rest_edges, deformed_edges = torch.split(edges, batch_size, dim=0)
edges_outer = torch.matmul(rest_edges[:, :, :, None], deformed_edges[:, :, None, :])
edges_outer = rearrange(edges_outer, 'b e d1 d2 -> e (b d1 d2)')
abs_incidence_mat = incidence_mat.clone()
abs_incidence_mat._values()[:] = torch.abs(abs_incidence_mat._values())
# transposed S
S = torch.sparse.mm(abs_incidence_mat.t(), edges_outer)
S = rearrange(S, 'v (b d1 d2) -> b v d2 d1', v=num_vertices, b=batch_size, d1=dimensions, d2=dimensions)
# SVD on gpu is extremely slow! https://github.com/pytorch/pytorch/pull/48436
device = S.device
U, _, V = torch.svd(S.cpu())
U = U.to(device)
V = V.to(device)
det_sign = torch.det(torch.matmul(U, V.transpose(-2, -1)))
U = torch.cat([U[..., :-1], U[..., -1:] * det_sign[..., None, None]], axis=-1)
rotations = torch.matmul(U, V.transpose(-2, -1))
return rotations
def compute_rotation(vertices_rest_pose, vertices_deformed_pose, edges):
"""
vertices_rest_pose: B x V x D
vertices_deformed_pose: B x V x D
edges: E x 2
"""
num_vertices = vertices_rest_pose.shape[1]
incidence_mat = edges_to_sparse_incidence(edges, num_vertices)
rot = compute_svd_rotation(vertices_rest_pose, vertices_deformed_pose, incidence_mat)
rot = pytorch3d.transforms.matrix_to_quaternion(rot)
return rot
def quaternion_normalize(quaternion, eps=1e-12):
"""
Adapted from tensorflow_graphics
Normalizes a quaternion.
Note:
In the following, A1 to An are optional batch dimensions.
Args:
quaternion: A tensor of shape `[A1, ..., An, 4]`, where the last dimension
represents a quaternion.
eps: A lower bound value for the norm that defaults to 1e-12.
name: A name for this op that defaults to "quaternion_normalize".
Returns:
A N-D tensor of shape `[?, ..., ?, 1]` where the quaternion elements have
been normalized.
Raises:
ValueError: If the shape of `quaternion` is not supported.
"""
return l2_normalize(quaternion, dim=-1, epsilon=eps)
def l2_normalize(x, dim=-1, epsilon=1e-12):
square_sum = torch.sum(x ** 2, dim=dim, keepdim=True)
x_inv_norm = torch.rsqrt(torch.clamp(square_sum, min=epsilon))
return x * x_inv_norm
def arap_energy(vertices_rest_pose,
vertices_deformed_pose,
quaternions,
edges,
vertex_weight=None,
edge_weight=None,
conformal_energy=True,
aggregate_loss=True):
"""
Adapted from tensorflow_graphics
Estimates an As Conformal As Possible (ACAP) fitting energy.
For a given mesh in rest pose, this function evaluates a variant of the ACAP
[1] fitting energy for a batch of deformed meshes. The vertex weights and edge
weights are defined on the rest pose.
The method implemented here is similar to [2], but with an added free variable
capturing a scale factor per vertex.
[1]: Yusuke Yoshiyasu, Wan-Chun Ma, Eiichi Yoshida, and Fumio Kanehiro.
"As-Conformal-As-Possible Surface Registration." Computer Graphics Forum. Vol.
33. No. 5. 2014.</br>
[2]: Olga Sorkine, and Marc Alexa.
"As-rigid-as-possible surface modeling". Symposium on Geometry Processing.
Vol. 4. 2007.
Note:
In the description of the arguments, V corresponds to
the number of vertices in the mesh, and E to the number of edges in this
mesh.
Note:
In the following, A1 to An are optional batch dimensions.
Args:
vertices_rest_pose: A tensor of shape `[V, 3]` containing the position of
all the vertices of the mesh in rest pose.
vertices_deformed_pose: A tensor of shape `[A1, ..., An, V, 3]` containing
the position of all the vertices of the mesh in deformed pose.
quaternions: A tensor of shape `[A1, ..., An, V, 4]` defining a rigid
transformation to apply to each vertex of the rest pose. See Section 2
from [1] for further details.
edges: A tensor of shape `[E, 2]` defining indices of vertices that are
connected by an edge.
vertex_weight: An optional tensor of shape `[V]` defining the weight
associated with each vertex. Defaults to a tensor of ones.
edge_weight: A tensor of shape `[E]` defining the weight of edges. Common
choices for these weights include uniform weighting, and cotangent
weights. Defaults to a tensor of ones.
conformal_energy: A `bool` indicating whether each vertex is associated with
a scale factor or not. If this parameter is True, scaling information must
be encoded in the norm of `quaternions`. If this parameter is False, this
function implements the energy described in [2].
aggregate_loss: A `bool` defining whether the returned loss should be an
aggregate measure. When True, the mean squared error is returned. When
False, returns two losses for every edge of the mesh.
name: A name for this op. Defaults to "as_conformal_as_possible_energy".
Returns:
When aggregate_loss is `True`, returns a tensor of shape `[A1, ..., An]`
containing the ACAP energies. When aggregate_loss is `False`, returns a
tensor of shape `[A1, ..., An, 2*E]` containing each term of the summation
described in the equation 7 of [2].
Raises:
ValueError: if the shape of `vertices_rest_pose`, `vertices_deformed_pose`,
`quaternions`, `edges`, `vertex_weight`, or `edge_weight` is not supported.
"""
# with tf.compat.v1.name_scope(name, "as_conformal_as_possible_energy", [
# vertices_rest_pose, vertices_deformed_pose, quaternions, edges,
# conformal_energy, vertex_weight, edge_weight
# ]):
# vertices_rest_pose = tf.convert_to_tensor(value=vertices_rest_pose)
# vertices_deformed_pose = tf.convert_to_tensor(value=vertices_deformed_pose)
# quaternions = tf.convert_to_tensor(value=quaternions)
# edges = tf.convert_to_tensor(value=edges)
# if vertex_weight is not None:
# vertex_weight = tf.convert_to_tensor(value=vertex_weight)
# if edge_weight is not None:
# edge_weight = tf.convert_to_tensor(value=edge_weight)
# shape.check_static(
# tensor=vertices_rest_pose,
# tensor_name="vertices_rest_pose",
# has_rank=2,
# has_dim_equals=(-1, 3))
# shape.check_static(
# tensor=vertices_deformed_pose,
# tensor_name="vertices_deformed_pose",
# has_rank_greater_than=1,
# has_dim_equals=(-1, 3))
# shape.check_static(
# tensor=quaternions,
# tensor_name="quaternions",
# has_rank_greater_than=1,
# has_dim_equals=(-1, 4))
# shape.compare_batch_dimensions(
# tensors=(vertices_deformed_pose, quaternions),
# last_axes=(-3, -3),
# broadcast_compatible=False)
# shape.check_static(
# tensor=edges, tensor_name="edges", has_rank=2, has_dim_equals=(-1, 2))
# tensors_with_vertices = [vertices_rest_pose,
# vertices_deformed_pose,
# quaternions]
# names_with_vertices = ["vertices_rest_pose",
# "vertices_deformed_pose",
# "quaternions"]
# axes_with_vertices = [-2, -2, -2]
# if vertex_weight is not None:
# shape.check_static(
# tensor=vertex_weight, tensor_name="vertex_weight", has_rank=1)
# tensors_with_vertices.append(vertex_weight)
# names_with_vertices.append("vertex_weight")
# axes_with_vertices.append(0)
# shape.compare_dimensions(
# tensors=tensors_with_vertices,
# axes=axes_with_vertices,
# tensor_names=names_with_vertices)
# if edge_weight is not None:
# shape.check_static(
# tensor=edge_weight, tensor_name="edge_weight", has_rank=1)
# shape.compare_dimensions(
# tensors=(edges, edge_weight),
# axes=(0, 0),
# tensor_names=("edges", "edge_weight"))
if not conformal_energy:
quaternions = quaternion_normalize(quaternions)
# Extracts the indices of vertices.
indices_i, indices_j = torch.unbind(edges, dim=-1)
# Extracts the vertices we need per term.
vertices_i_rest = vertices_rest_pose[..., indices_i, :]
vertices_j_rest = vertices_rest_pose[..., indices_j, :]
vertices_i_deformed = vertices_deformed_pose[..., indices_i, :]
vertices_j_deformed = vertices_deformed_pose[..., indices_j, :]
# Extracts the weights we need per term.
weights_shape = vertices_i_rest.shape[-2]
if vertex_weight is not None:
weight_i = vertex_weight[indices_i]
weight_j = vertex_weight[indices_j]
else:
weight_i = weight_j = torch.ones(weights_shape, dtype=vertices_rest_pose.dtype, device=vertices_rest_pose.device)
weight_i = weight_i[..., None]
weight_j = weight_j[..., None]
if edge_weight is not None:
weight_ij = edge_weight
else:
weight_ij = torch.ones(weights_shape, dtype=vertices_rest_pose.dtype, device=vertices_rest_pose.device)
weight_ij = weight_ij[..., None]
# Extracts the rotation we need per term.
quaternion_i = quaternions[..., indices_i, :]
quaternion_j = quaternions[..., indices_j, :]
# Computes the energy.
deformed_ij = vertices_i_deformed - vertices_j_deformed
rotated_rest_ij = pytorch3d.transforms.quaternion_apply(quaternion_i, (vertices_i_rest - vertices_j_rest))
energy_ij = weight_i * weight_ij * (deformed_ij - rotated_rest_ij)
deformed_ji = vertices_j_deformed - vertices_i_deformed
rotated_rest_ji = pytorch3d.transforms.quaternion_apply(quaternion_j, (vertices_j_rest - vertices_i_rest))
energy_ji = weight_j * weight_ij * (deformed_ji - rotated_rest_ji)
energy_ij_squared = torch.sum(energy_ij ** 2, dim=-1)
energy_ji_squared = torch.sum(energy_ji ** 2, dim=-1)
if aggregate_loss:
average_energy_ij = torch.mean(energy_ij_squared, dim=-1)
average_energy_ji = torch.mean(energy_ji_squared, dim=-1)
return (average_energy_ij + average_energy_ji) / 2.0
return torch.cat((energy_ij_squared, energy_ji_squared), dim=-1)
def arap_loss(vertices_rest_pose, vertices_deformed_pose, edges):
# squash batch dimensions
vertices_rest_pose_shape = list(vertices_rest_pose.shape)
vertices_deformed_pose_shape = list(vertices_deformed_pose.shape)
vertices_rest_pose = vertices_rest_pose.reshape([-1] + vertices_rest_pose_shape[-2:])
vertices_deformed_pose = vertices_deformed_pose.reshape([-1] + vertices_deformed_pose_shape[-2:])
# try:
quaternions = compute_rotation(vertices_rest_pose, vertices_deformed_pose, edges)
# except RuntimeError:
# print('SVD did not converge')
# batch_size = vertices_rest_pose.shape[0]
# num_vertices = vertices_rest_pose.shape[-2]
# quaternions = pytorch3d.transforms.matrix_to_quaternion(pytorch3d.transforms.euler_angles_to_matrix(torch.zeros([batch_size, num_vertices, 3], device=vertices_rest_pose.device), 'XYZ'))
quaternions = quaternions.detach()
energy = arap_energy(
vertices_rest_pose,
vertices_deformed_pose,
quaternions,
edges,
aggregate_loss=True,
conformal_energy=False)
return energy.reshape(vertices_rest_pose_shape[:-2])
|