Spaces:
Runtime error
Runtime error
import torch | |
# code from https://raw.githubusercontent.com/yufu-wang/aves/main/optimization/loss_arap.py | |
class Arap_Loss():
    '''
    PyTorch implementation of an as-rigid-as-possible (ARAP) loss.

    The rest shape is captured from ``meshes`` at construction time. Calling
    the instance with deformed meshes of the same connectivity returns

        sum over directed edges (i, j) of
            w_i * w_ij * || e'_ij - R_i e_ij ||          (unsquared norm)
        divided by the batch size,

    where ``e_ij = v_i - v_j`` are rest-shape edge vectors, ``e'_ij`` the
    deformed ones, ``w_ij`` the cotangent weights (negative ones clamped to
    0), ``w_i`` the optional per-vertex weights and ``R_i`` the per-vertex
    rotation that best aligns the rest edges of vertex i with its deformed
    edges.

    ``meshes`` / ``new_meshes`` are assumed to be pytorch3d-style ``Meshes``
    objects (the code calls ``verts_packed``, ``faces_packed``,
    ``edges_packed``, ``num_verts_per_mesh`` and ``_compute_packed``).
    '''

    def __init__(self, meshes, device='cpu', vertex_w=None):
        '''
        Precompute the rest-shape quantities for ``meshes``.

        Args:
            meshes: batch of rest-shape meshes.
            device: stored as ``self.device`` for backward compatibility; all
                buffers are kept on the device of ``meshes``.
            vertex_w: optional per-vertex weights indexed along dim 0 by
                packed vertex index; ``None`` disables per-vertex weighting.
        '''
        # The rest shape is a constant target, so no autograd graph is needed.
        with torch.no_grad():
            self.device = device
            self.bn = len(meshes)

            # Directed edges (both directions of every undirected edge), sorted
            # lexicographically by (source, target). This is the same order
            # coalesce() gives the Laplacian's non-zeros, so self.wij[k] is the
            # weight of edge self.idx[:, k] (assumes the face edges and
            # edges_packed() describe the same edge set).
            e0, e1 = meshes.edges_packed().unbind(1)
            idx = torch.stack([torch.cat([e0, e1]), torch.cat([e1, e0])], dim=0)
            self.idx = self.sort_idx(idx)

            # Cotangent edge weights; negative ones (obtuse opposite angles)
            # are clamped to zero.
            L = self.get_laplacian_cot(meshes)
            self.wij = L.values().clone()
            self.wij[self.wij < 0] = 0.

            # Number of outgoing directed edges (i.e. neighbours) per packed vertex.
            V = int(meshes.num_verts_per_mesh().sum())
            self.deg = torch.bincount(idx[0], minlength=V)

            # Rest-shape edge vectors, in self.idx order.
            self.eij = self.get_edges(meshes)

            # Optional per-vertex regularization strength.
            self.vertex_w = vertex_w

    def __call__(self, new_meshes):
        '''Return the ARAP loss of ``new_meshes`` w.r.t. the stored rest shape.'''
        # Presumably refreshes pytorch3d's packed caches after the vertices
        # were updated (private pytorch3d API).
        new_meshes._compute_packed()
        optimal_R = self.step_1(new_meshes)
        return self.step_2(optimal_R, new_meshes)

    def step_1(self, new_meshes):
        '''
        Local step: per-vertex rotations aligning rest edges to deformed edges.

        For every vertex i this forms the weighted covariance
        S_i = sum_j w_ij * e_ij e'_ij^T, takes its SVD S_i = U Sigma V^T and
        returns R_i = V U^T, flipping the last column of U when det < 0 so
        that every R_i is a proper rotation rather than a reflection.

        Returns:
            (V, 3, 3) tensor with one rotation per packed vertex, computed
            without gradient tracking, on the device of the rest edges.
        '''
        with torch.no_grad():
            # The batched 3x3 SVD is run on CPU (presumably because many tiny
            # SVDs are faster there); the result is moved back at the end.
            eij = self.eij.cpu()
            eij_ = self.get_edges(new_meshes).cpu().to(eij.dtype)
            wij = self.wij.cpu().to(eij.dtype)
            src = self.idx[0].cpu()

            # Weighted outer products w_ij * e_ij e'_ij^T, accumulated per
            # source vertex (vectorized; no per-vertex Python loop).
            outer = wij[:, None, None] * eij[:, :, None] * eij_[:, None, :]
            S = eij.new_zeros((self.deg.shape[0], 3, 3))
            S.index_add_(0, src, outer)

            u, _, v = torch.svd(S)
            det = torch.det(v @ u.transpose(-2, -1))
            u[det < 0, :, -1] *= -1
            R = v @ u.transpose(-2, -1)
        # step_2 combines R with self.eij, so return it on that device.
        return R.to(self.eij.device)

    def step_2(self, R, new_meshes):
        '''
        Compute the ARAP energy of ``new_meshes`` for fixed rotations ``R``.

        Args:
            R: (V, 3, 3) per-vertex rotations, as returned by ``step_1``.
            new_meshes: deformed meshes; gradients flow through their vertices.

        Returns:
            Scalar tensor: sum over directed edges (i, j) of
            w_i * w_ij * ||e'_ij - R_i e_ij|| (unsquared), divided by the
            batch size.
        '''
        src = self.idx[0]
        Reij = (R[src] @ self.eij.unsqueeze(-1)).squeeze(-1)
        eij_ = self.get_edges(new_meshes)
        arap_loss = self.wij * (eij_ - Reij).norm(dim=1)
        if self.vertex_w is not None:
            arap_loss = arap_loss * self.vertex_w[src]
        return arap_loss.sum() / self.bn

    def get_edges(self, meshes):
        '''
        Return the edge vectors v_i - v_j of ``meshes`` for every directed
        edge (i, j) in ``self.idx``, one row per edge in ``self.idx`` order.
        '''
        verts_packed = meshes.verts_packed()
        return verts_packed[self.idx[0]] - verts_packed[self.idx[1]]

    def sort_idx(self, idx):
        '''
        Return the columns of the (2, E) integer tensor ``idx`` sorted
        lexicographically by (idx[0], idx[1]).

        An exact integer key is used: a float key such as
        ``idx[0] + idx[1] * 1e-6`` is evaluated in float32 and loses the
        secondary order (and, near 1e6 vertices, the primary one), which
        would misalign edges with their cotangent weights.
        '''
        if idx.shape[1] == 0:
            return idx
        stride = int(idx.max()) + 1
        order = torch.argsort(idx[0] * stride + idx[1])
        return idx[:, order]

    def get_laplacian_cot(self, meshes):
        '''
        Build the symmetric cotangent-weight matrix of the packed meshes.

        Routine modified from pytorch3d/loss/mesh_laplacian_smoothing.py.

        Returns:
            Coalesced sparse (V, V) COO tensor whose (i, j) entry is
            (cot(alpha_ij) + cot(beta_ij)) / 2, where alpha_ij and beta_ij
            are the angles opposite edge (i, j) in its adjacent faces.
        '''
        verts_packed = meshes.verts_packed()
        faces_packed = meshes.faces_packed()
        V, F = verts_packed.shape[0], faces_packed.shape[0]

        face_verts = verts_packed[faces_packed]
        v0, v1, v2 = face_verts[:, 0], face_verts[:, 1], face_verts[:, 2]

        # Side lengths: A is opposite v0, B opposite v1, C opposite v2.
        A = (v1 - v2).norm(dim=1)
        B = (v0 - v2).norm(dim=1)
        C = (v0 - v1).norm(dim=1)

        # Heron's formula, clamped so degenerate faces do not divide by zero.
        s = 0.5 * (A + B + C)
        area = (s * (s - A) * (s - B) * (s - C)).clamp_(min=1e-12).sqrt()

        # cot(angle opposite side a) = (b^2 + c^2 - a^2) / (4 * area).
        A2, B2, C2 = A * A, B * B, C * C
        cota = (B2 + C2 - A2) / area
        cotb = (A2 + C2 - B2) / area
        cotc = (A2 + B2 - C2) / area
        cot = torch.stack([cota, cotb, cotc], dim=1) / 4.0

        # The cot of the angle at v0 belongs to edge (v1, v2), and so on.
        ii = faces_packed[:, [1, 2, 0]]
        jj = faces_packed[:, [2, 0, 1]]
        idx = torch.stack([ii, jj], dim=0).view(2, F * 3)
        # sparse_coo_tensor (unlike the legacy torch.sparse.FloatTensor
        # constructor) accepts CUDA inputs and keeps the input dtype.
        L = torch.sparse_coo_tensor(idx, cot.view(-1), (V, V))
        # Symmetrize so each edge collects the cot from both adjacent faces,
        # then halve (ARAP paper normalization). Coalesce last so duplicates
        # are summed and indices end up sorted in (row, col) order.
        return ((L + L.t()) * 0.5).coalesce()