File size: 12,356 Bytes
98a77e0
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
# import pytorch3d
import torch
from einops import rearrange
from torch._C import device


def edges_to_sparse_incidence(edges, num_vertices, dtype=torch.float32):
    """Build the signed edge-vertex incidence matrix as a sparse COO tensor.

    Row ``k`` holds ``+1`` at column ``edges[k, 0]`` and ``-1`` at column
    ``edges[k, 1]``, so ``incidence @ X`` for a ``V x C`` matrix ``X`` yields the
    per-edge differences ``X[edges[k, 0]] - X[edges[k, 1]]``.

    Args:
        edges: Integer tensor of shape ``E x 2`` holding vertex indices.
        num_vertices: Number of vertices ``V`` (number of columns).
        dtype: Dtype of the non-zero values. Defaults to ``torch.float32``,
            the dtype that used to be hard-coded.

    Returns:
        Sparse COO tensor of shape ``E x V`` on ``edges.device`` (not coalesced).
    """
    num_edges = edges.shape[0]
    # Each edge contributes two entries to its own row: (k, i) and (k, j).
    row_indexes = torch.arange(num_edges, dtype=torch.long, device=edges.device).repeat_interleave(2)
    col_indexes = edges.reshape(-1).long()
    indexes = torch.stack([row_indexes, col_indexes])
    values = torch.tensor([1.0, -1.0], dtype=dtype, device=edges.device).repeat(num_edges)
    # torch.sparse_coo_tensor replaces the deprecated legacy constructor
    # torch.sparse.FloatTensor, which is the CPU-typed sparse tensor class.
    return torch.sparse_coo_tensor(indexes, values, (num_edges, num_vertices))


def compute_svd_rotation(vertices_rest_pose, vertices_deformed_pose, incidence_mat):
    """
    Fit per-vertex best rotations (the ARAP local step) with an SVD.

    Adapted from:
    https://github.com/kzhou23/shape_pose_disent/blob/a8017c405892c98f52fa9775327172633290b1d8/arap.py#L76

    For each vertex ``i`` this accumulates ``sum_j e'_ij e_ij^T`` over its
    incident edges (deformed edge times rest edge, i.e. the transpose of the
    usual ARAP covariance ``S_i``). It returns the closest proper rotation
    ``R_i`` such that ``e'_ij ~= R_i e_ij``.

    Args:
        vertices_rest_pose: B x V x D tensor.
        vertices_deformed_pose: B x V x D tensor (same batch size as the rest pose).
        incidence_mat: E x V sparse tensor with one +1 / -1 pair per edge row,
            e.g. as built by ``edges_to_sparse_incidence``. Its dtype must match
            the vertices, because it is used in ``torch.sparse.mm``.

    Returns:
        B x V x D x D tensor of rotation matrices (determinant +1).
    """
    batch_size, num_vertices, dimensions = vertices_rest_pose.shape
    num_edges = incidence_mat.shape[0]
    num_poses = 2 * batch_size

    # 2B x V x D -> V x (D * 2B), so a single sparse matmul differences the
    # edges of both poses of every batch element at once.
    vertices = torch.cat((vertices_rest_pose, vertices_deformed_pose), dim=0)
    vertices = vertices.permute(1, 2, 0).reshape(num_vertices, dimensions * num_poses)
    # E x V . V x (D * 2B) -> E x (D * 2B) -> 2B x E x D
    edges = torch.sparse.mm(incidence_mat, vertices)
    edges = edges.reshape(num_edges, dimensions, num_poses).permute(2, 0, 1)
    rest_edges, deformed_edges = edges[:batch_size], edges[batch_size:]

    # Per-edge outer products e_ij e'_ij^T: B x E x D x D -> E x (B * D * D)
    edges_outer = torch.matmul(rest_edges[:, :, :, None], deformed_edges[:, :, None, :])
    edges_outer = edges_outer.permute(1, 0, 2, 3).reshape(num_edges, batch_size * dimensions * dimensions)

    # |incidence| adds every edge term to both of its endpoint vertices.
    abs_incidence_mat = incidence_mat.clone()
    abs_incidence_mat._values()[:] = torch.abs(abs_incidence_mat._values())

    # V x (B * D * D) -> B x V x D x D. Swapping the last two axes stores the
    # transposed covariance sum_j e'_ij e_ij^T.
    S = torch.sparse.mm(abs_incidence_mat.t(), edges_outer)
    S = S.reshape(num_vertices, batch_size, dimensions, dimensions).permute(1, 0, 3, 2)

    # SVD on gpu is extremely slow! https://github.com/pytorch/pytorch/pull/48436
    target_device = S.device
    U, _, Vh = torch.linalg.svd(S.cpu(), full_matrices=False)
    U = U.to(target_device)
    Vh = Vh.to(target_device)

    # Where U @ Vh is a reflection, negate the column of U paired with the
    # smallest singular value (singular values come back in descending order).
    # The exact sign is used instead of the raw determinant, which is only ~±1.
    det = torch.det(torch.matmul(U, Vh))
    det_sign = torch.where(det < 0, -torch.ones_like(det), torch.ones_like(det))
    U = torch.cat([U[..., :-1], U[..., -1:] * det_sign[..., None, None]], dim=-1)

    return torch.matmul(U, Vh)


def compute_rotation(vertices_rest_pose, vertices_deformed_pose, edges):
    """Fit per-vertex best rotations and return them as quaternions.

    Args:
        vertices_rest_pose: B x V x D tensor.
        vertices_deformed_pose: B x V x D tensor.
        edges: E x 2 integer tensor of vertex indices.

    Returns:
        Output of ``pytorch3d.transforms.matrix_to_quaternion`` applied to the
        B x V x 3 x 3 rotations, presumably B x V x 4 (in pytorch3d's
        quaternion convention).
    """
    # Imported here because the module-level pytorch3d import is not active;
    # a bare `pytorch3d` reference would otherwise raise NameError. Importing
    # lazily also keeps the module importable without pytorch3d installed.
    import pytorch3d.transforms

    num_vertices = vertices_rest_pose.shape[1]
    incidence_mat = edges_to_sparse_incidence(edges, num_vertices)
    rotations = compute_svd_rotation(vertices_rest_pose, vertices_deformed_pose, incidence_mat)
    return pytorch3d.transforms.matrix_to_quaternion(rotations)


def quaternion_normalize(quaternion, eps=1e-12):
    """
    Adapted from tensorflow_graphics

    Normalizes a quaternion to unit length.

    Note:
      In the following, A1 to An are optional batch dimensions.

    Args:
      quaternion: A tensor of shape `[A1, ..., An, 4]`, where the last dimension
        represents a quaternion.
      eps: Lower bound applied to the squared norm before it is inverted (see
        `l2_normalize`). This avoids a division by zero for zero quaternions.
        Defaults to 1e-12.

    Returns:
      A tensor of shape `[A1, ..., An, 4]` in which each quaternion has been
      divided by its L2 norm.
    """
    return l2_normalize(quaternion, dim=-1, epsilon=eps)


def l2_normalize(x, dim=-1, epsilon=1e-12):
    """Scale ``x`` to unit L2 norm along ``dim``.

    The squared norm is clamped from below by ``epsilon`` before taking its
    inverse square root. Vectors with a squared norm under ``epsilon`` are
    therefore divided by ``sqrt(epsilon)`` rather than by zero.
    """
    inv_norm = x.pow(2).sum(dim=dim, keepdim=True).clamp(min=epsilon).rsqrt()
    return x * inv_norm
    

def arap_energy(vertices_rest_pose,
                vertices_deformed_pose,
                quaternions,
                edges,
                vertex_weight=None,
                edge_weight=None,
                conformal_energy=True,
                aggregate_loss=True):
    """
    Adapted from tensorflow_graphics

    Estimates an As Conformal As Possible (ACAP) fitting energy.

    For a given mesh in rest pose, this function evaluates a variant of the ACAP
    [1] fitting energy for a batch of deformed meshes. The vertex weights and edge
    weights are defined on the rest pose.
    The method implemented here is similar to [2], but with an added free variable
    capturing a scale factor per vertex.

    [1]: Yusuke Yoshiyasu, Wan-Chun Ma, Eiichi Yoshida, and Fumio Kanehiro.
    "As-Conformal-As-Possible Surface Registration." Computer Graphics Forum.
    Vol. 33. No. 5. 2014.
    [2]: Olga Sorkine, and Marc Alexa.
    "As-rigid-as-possible surface modeling". Symposium on Geometry Processing.
    Vol. 4. 2007.

    Note:
      V is the number of vertices in the mesh and E the number of edges.
      A1 to An are optional batch dimensions.

    Args:
      vertices_rest_pose: A tensor of shape `[V, 3]` containing the rest-pose
        vertex positions. A batched `[A1, ..., An, V, 3]` rest pose also works,
        because vertices are gathered along axis -2; `arap_loss` calls it this way.
      vertices_deformed_pose: A tensor of shape `[A1, ..., An, V, 3]` containing
        the deformed-pose vertex positions.
      quaternions: A tensor of shape `[A1, ..., An, V, 4]` with one rotation per
        vertex. Each rotation is applied to that vertex's rest-pose edges through
        `pytorch3d.transforms.quaternion_apply`, so it must follow pytorch3d's
        quaternion component order (presumably real part first; confirm
        against the pytorch3d docs). See Section 2 of [1].
      edges: An integer tensor of shape `[E, 2]` with the indices of vertices
        connected by an edge.
      vertex_weight: An optional tensor of shape `[V]` with a weight per vertex.
        Defaults to ones.
      edge_weight: An optional tensor of shape `[E]` with a weight per edge.
        Common choices include uniform and cotangent weights. Defaults to ones.
      conformal_energy: If True, each vertex has a free scale factor, encoded
        in the norm of its quaternion. If False, the quaternions are normalized
        first and this computes the energy described in [2].
      aggregate_loss: If True, return the mean squared error over edges.
        If False, return the two per-edge terms.

    Returns:
      When aggregate_loss is True, a tensor of shape `[A1, ..., An]` with the
      ACAP energies. Otherwise, a tensor of shape `[A1, ..., An, 2*E]` with each
      term of the summation in equation 7 of [2]: all i->j terms, followed by
      all j->i terms.
    """
    # Imported here because the module-level pytorch3d import is not active;
    # a bare `pytorch3d` reference would otherwise raise NameError.
    import pytorch3d.transforms

    if not conformal_energy:
        quaternions = quaternion_normalize(quaternions)
    # Extracts the indices of vertices.
    indices_i, indices_j = torch.unbind(edges, dim=-1)
    # Extracts the vertices we need per term.
    vertices_i_rest = vertices_rest_pose[..., indices_i, :]
    vertices_j_rest = vertices_rest_pose[..., indices_j, :]
    vertices_i_deformed = vertices_deformed_pose[..., indices_i, :]
    vertices_j_deformed = vertices_deformed_pose[..., indices_j, :]
    # Extracts the weights we need per term (one entry per edge).
    num_edges = vertices_i_rest.shape[-2]
    if vertex_weight is not None:
        weight_i = vertex_weight[indices_i]
        weight_j = vertex_weight[indices_j]
    else:
        weight_i = weight_j = torch.ones(num_edges, dtype=vertices_rest_pose.dtype, device=vertices_rest_pose.device)
    # Trailing singleton axis so the weights broadcast over the xyz components.
    weight_i = weight_i[..., None]
    weight_j = weight_j[..., None]
    if edge_weight is not None:
        weight_ij = edge_weight
    else:
        weight_ij = torch.ones(num_edges, dtype=vertices_rest_pose.dtype, device=vertices_rest_pose.device)
    weight_ij = weight_ij[..., None]
    # Extracts the rotation we need per term.
    quaternion_i = quaternions[..., indices_i, :]
    quaternion_j = quaternions[..., indices_j, :]
    # Computes the energy in both edge directions, each rotated by the
    # quaternion of the edge's source vertex.
    deformed_ij = vertices_i_deformed - vertices_j_deformed
    rotated_rest_ij = pytorch3d.transforms.quaternion_apply(quaternion_i, (vertices_i_rest - vertices_j_rest))
    energy_ij = weight_i * weight_ij * (deformed_ij - rotated_rest_ij)
    deformed_ji = vertices_j_deformed - vertices_i_deformed
    rotated_rest_ji = pytorch3d.transforms.quaternion_apply(quaternion_j, (vertices_j_rest - vertices_i_rest))
    energy_ji = weight_j * weight_ij * (deformed_ji - rotated_rest_ji)
    energy_ij_squared = torch.sum(energy_ij ** 2, dim=-1)
    energy_ji_squared = torch.sum(energy_ji ** 2, dim=-1)
    if aggregate_loss:
        average_energy_ij = torch.mean(energy_ij_squared, dim=-1)
        average_energy_ji = torch.mean(energy_ji_squared, dim=-1)
        return (average_energy_ij + average_energy_ji) / 2.0
    return torch.cat((energy_ij_squared, energy_ji_squared), dim=-1)


def arap_loss(vertices_rest_pose, vertices_deformed_pose, edges):
    """As-rigid-as-possible loss between rest-pose and deformed vertices.

    Leading batch dimensions of both inputs are flattened. Per-vertex rotations
    are then fitted with `compute_rotation` and detached, so no gradient flows
    through the rotation fit. Finally `arap_energy` is evaluated with
    `conformal_energy=False`, which normalizes the quaternions.

    Args:
        vertices_rest_pose: `[A1, ..., An, V, 3]` rest-pose vertices.
        vertices_deformed_pose: Deformed vertices. The rotation fit concatenates
            both poses along the flattened batch axis, so this presumably needs
            the same batch dimensions as `vertices_rest_pose`.
        edges: `[E, 2]` vertex index pairs.

    Returns:
        The aggregated energy, reshaped to the batch dimensions `[A1, ..., An]`
        of `vertices_rest_pose`.
    """
    batch_dims = list(vertices_rest_pose.shape)[:-2]
    flat_rest = vertices_rest_pose.reshape(-1, *vertices_rest_pose.shape[-2:])
    flat_deformed = vertices_deformed_pose.reshape(-1, *vertices_deformed_pose.shape[-2:])

    # The fitted rotations act as a fixed target: no gradient through the SVD.
    fitted_quaternions = compute_rotation(flat_rest, flat_deformed, edges).detach()

    energy = arap_energy(
        flat_rest,
        flat_deformed,
        fitted_quaternions,
        edges,
        aggregate_loss=True,
        conformal_energy=False,
    )
    return energy.reshape(batch_dims)