Nadine Rueegg
initial commit with code and data
753fd9a
import torch
# code from https://raw.githubusercontent.com/yufu-wang/aves/main/optimization/loss_arap.py
class Arap_Loss():
    '''
    PyTorch implementation of an as-rigid-as-possible (ARAP) loss.

    Rest-pose edges, clamped cotangent weights and vertex degrees are captured
    once from `meshes` at construction time. Calling the instance with a
    deformed batch of meshes (same topology) estimates one best-fit rotation
    per vertex (step_1) and returns the weighted edge residual summed over all
    edges and divided by the batch size (step_2).
    '''

    def __init__(self, meshes, device='cpu', vertex_w=None):
        '''
        Args:
            meshes: rest-pose mesh batch. Only len(), verts_packed(),
                faces_packed(), edges_packed() and num_verts_per_mesh() are
                used (presumably a pytorch3d Meshes object). step_1 assumes
                every mesh in the batch shares the first mesh's topology.
            device: device the per-vertex rotations are moved to after being
                estimated on the CPU in step_1.
            vertex_w: optional per-vertex weights, repeated once per outgoing
                edge in step_2 (presumably shape (V,) over packed vertices --
                TODO confirm against callers).

        Raises:
            ValueError: if the cotangent Laplacian and the edge list do not
                describe the same set of directed edges.
        '''
        with torch.no_grad():
            self.device = device
            self.bn = len(meshes)
            # Cotangent Laplacian: its coalesced values are ordered by
            # (row, col), one value per directed edge. Negative weights
            # (from obtuse angles) are clamped to zero.
            L = self.get_laplacian_cot(meshes)
            self.wij = L.values().clone()
            self.wij[self.wij < 0] = 0.
            # Directed adjacency: both orientations of every undirected edge.
            V = int(meshes.num_verts_per_mesh().sum())
            e0, e1 = meshes.edges_packed().unbind(1)
            idx01 = torch.stack([e0, e1], dim=1)
            idx10 = torch.stack([e1, e0], dim=1)
            idx = torch.cat([idx01, idx10], dim=0).t()
            # Out-degree of each vertex = number of directed edges leaving it.
            self.deg = torch.bincount(idx[0], minlength=V).long()
            # Sort edges by (source, target) so that they line up one-to-one
            # with the coalesced Laplacian values in self.wij.
            self.idx = self.sort_idx(idx)
            if self.idx.shape[1] != self.wij.shape[0]:
                raise ValueError(
                    'Laplacian has %d entries but the mesh has %d directed '
                    'edges; the ARAP weights cannot be aligned with the edges'
                    % (self.wij.shape[0], self.idx.shape[1]))
            # Rest-pose edge vectors v_i - v_j in self.idx order.
            self.eij = self.get_edges(meshes)
            # Optional per-vertex regularization strength.
            self.vertex_w = vertex_w

    def __call__(self, new_meshes):
        '''Return the ARAP loss of `new_meshes` relative to the rest pose.'''
        # Private pytorch3d-style call; presumably refreshes the packed
        # vertex caches of the deformed meshes -- confirm with the mesh class.
        new_meshes._compute_packed()
        optimal_R = self.step_1(new_meshes)
        arap_loss = self.step_2(optimal_R, new_meshes)
        return arap_loss

    def step_1(self, new_meshes):
        '''
        Estimate the best-fit rotation of every vertex (no gradients).

        For each vertex i, S_i = sum_j w_ij e_ij e'_ij^T is built from the
        rest-pose edges e_ij and the deformed edges e'_ij. With the SVD
        S_i = U Sigma V^T, the rotation is R_i = V U^T, corrected to a proper
        rotation when det(R_i) < 0.

        Returns:
            Tensor of shape (bn * n_verts, 3, 3) on `self.device`.
        '''
        bn = self.bn
        # Per-mesh views: edges (bn, E, 3) and weights (bn, E).
        eij = self.eij.view(bn, -1, 3).cpu()
        with torch.no_grad():
            eij_ = self.get_edges(new_meshes).view(bn, -1, 3).cpu()
        wij = self.wij.view(bn, -1).cpu()
        # Degrees of the first mesh; all meshes are assumed to share this topology.
        deg_1 = self.deg.view(bn, -1)[0].cpu()
        n_verts = deg_1.shape[0]
        # Edges are sorted by source vertex, so the source of edge k is
        # obtained by repeating each vertex id deg times.
        src = torch.repeat_interleave(torch.arange(n_verts), deg_1)
        # Weighted outer products w * e e'^T per edge, scatter-added per vertex
        # (replaces a per-vertex Python loop with O(V^2) prefix sums).
        outer = wij[..., None, None] * eij.unsqueeze(-1) * eij_.unsqueeze(-2)
        S = torch.zeros(bn, n_verts, 3, 3, dtype=outer.dtype)
        S.index_add_(1, src, outer)
        S = S.view(-1, 3, 3)
        u, _, v = torch.svd(S)
        R = v @ u.transpose(-2, -1)
        # torch.svd returns singular values in descending order, so the last
        # column of U belongs to the smallest one; flipping it turns a
        # reflection into the closest proper rotation.
        det = torch.det(R)
        u[det < 0, :, -1] *= -1
        R = v @ u.transpose(-2, -1)
        return R.to(self.device)

    def step_2(self, R, new_meshes):
        '''
        Return the ARAP energy for the given per-vertex rotations.

        The residual of each directed edge is w_ij * ||e'_ij - R_i e_ij||
        (note: the norm is not squared here). Residuals are optionally scaled
        by the source vertex's weight, summed, and divided by the batch size.
        '''
        # One rotation per directed edge, taken from its source vertex.
        R = torch.repeat_interleave(R, self.deg, dim=0)
        # squeeze(-1) only drops the matrix-vector column dim, even for one edge.
        Reij = (R @ self.eij.unsqueeze(-1)).squeeze(-1)
        eij_ = self.get_edges(new_meshes)
        arap_loss = self.wij * (eij_ - Reij).norm(dim=1)
        if self.vertex_w is not None:
            vertex_w = torch.repeat_interleave(self.vertex_w, self.deg, dim=0)
            arap_loss = arap_loss * vertex_w
        arap_loss = arap_loss.sum() / self.bn
        return arap_loss

    def get_edges(self, meshes):
        '''
        Return the edge vectors v_i - v_j, one row per directed edge.

        The rows follow the order of self.idx, which must be sorted by source
        vertex so that repeating each vertex deg times yields v_i.
        '''
        verts_packed = meshes.verts_packed()
        vi = torch.repeat_interleave(verts_packed, self.deg, dim=0)
        vj = verts_packed[self.idx[1]]
        eij = vi - vj
        return eij

    def sort_idx(self, idx):
        '''
        Return the columns of a (2, E) index tensor sorted by (row, col).

        An exact int64 key is used, so the order matches the coalesced order
        of torch sparse tensors for any vertex count. A float key such as
        row + col * 1e-6 loses precision and mis-orders neighbours.
        '''
        if idx.numel() == 0:
            return idx
        n = int(idx.max()) + 1
        key = idx[0].long() * n + idx[1].long()
        return idx[:, torch.argsort(key)]

    def get_laplacian_cot(self, meshes):
        '''
        Build the cotangent Laplacian weights as a coalesced sparse (V, V) tensor.

        Routine modified from pytorch3d/loss/mesh_laplacian_smoothing.py.
        Entry (i, j) is (cot alpha_ij + cot beta_ij) / 4 for the two angles
        opposite edge (i, j). That is the symmetrised sum halved, following
        the ARAP paper's normalization.
        '''
        verts_packed = meshes.verts_packed()
        faces_packed = meshes.faces_packed()
        V, F = verts_packed.shape[0], faces_packed.shape[0]
        face_verts = verts_packed[faces_packed]
        v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]
        # Side lengths opposite v0, v1, v2.
        A = (v1 - v2).norm(dim=1)
        B = (v0 - v2).norm(dim=1)
        C = (v0 - v1).norm(dim=1)
        # Heron's formula, clamped to avoid dividing by zero for degenerate faces.
        s = 0.5 * (A + B + C)
        area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()
        A2, B2, C2 = A * A, B * B, C * C
        cota = (B2 + C2 - A2) / area
        cotb = (A2 + C2 - B2) / area
        cotc = (A2 + B2 - C2) / area
        cot = torch.stack([cota, cotb, cotc], dim=1)
        cot /= 4.0
        # The cotangent at corner k is attached to the opposite edge.
        ii = faces_packed[:, [1, 2, 0]]
        jj = faces_packed[:, [2, 0, 1]]
        idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
        L = torch.sparse_coo_tensor(idx, cot.view(-1), (V, V))
        L = L + L.t()
        L = L.coalesce()
        L /= 2.0  # normalized according to arap paper
        return L