Spaces:
Running
on
T4
Running
on
T4
# Cage code adapted from https://github.com/yifita/deep_cage
import torch | |
import numpy as np | |
import trimesh | |
def deform_with_MVC(cage, cage_deformed, cage_face, query, verbose=False):
    """
    Deform query points with the mean value coordinates (MVC) of a cage.

    The MVC weights of every query point are computed w.r.t. the vertices of
    ``cage`` and then applied to the vertices of ``cage_deformed``.

    cage (B,C,3)
    cage_deformed (B,C,3)
    cage_face (B,F,3) int64
    query (B,Q,3)

    return:
        deformed (B,Q,3); when ``verbose`` is True, a tuple
        (deformed, weights, weights_unnormed) where ``weights`` are the
        normalized per-vertex MVC weights and ``weights_unnormed`` is the
        second value returned by ``mean_value_coordinates_3D(..., verbose=True)``
        (its per-face, per-corner weights).
    """
    mvc_weights, face_weights = mean_value_coordinates_3D(query, cage, cage_face, verbose=True)
    # (B,Q,C,1) * (B,1,C,3) -> (B,Q,C,3), then sum over the cage vertices C.
    weighted_vertices = mvc_weights[..., None] * cage_deformed[:, None]
    deformed = weighted_vertices.sum(dim=2)
    return (deformed, mvc_weights, face_weights) if verbose else deformed
def loadInitCage(template):
    """
    Load the initial cage mesh from the file ``template``.

    return:
        init_cage_V: float32 tensor (1,V,3) built from the first three vertex
            columns, scaled by 2.0
        init_cage_F: int64 tensor (1,F,3) built from the first three face
            columns
    """
    raw_vertices, raw_faces = read_trimesh(template)
    xyz = raw_vertices[:, :3].astype(np.float32)
    tri = raw_faces[:, :3].astype(np.int64)
    # Add a leading batch dimension; the cage is enlarged by a factor of 2.
    init_cage_V = torch.from_numpy(xyz).unsqueeze(0) * 2.0
    init_cage_F = torch.from_numpy(tri).unsqueeze(0)
    return init_cage_V, init_cage_F
def read_trimesh(path):
    """
    Load a triangle mesh from ``path`` with trimesh.

    ``force="mesh"`` asks trimesh to merge files it would otherwise load as a
    ``trimesh.Scene`` (e.g. some multi-object OBJ/GLB files) into a single
    mesh, because a Scene has no ``.vertices``/``.faces`` attributes.

    return:
        (vertices, faces) arrays of the loaded mesh.
    """
    mesh = trimesh.load(path, force="mesh")
    return mesh.vertices, mesh.faces
# Utility functions adapted from pytorch_points.
# Full double-precision pi. The old literal 3.1415927 equals float32(pi), which
# is still what float32 tensors see; float64 tensors now get the exact value.
PI = np.pi
def normalize_to_box(input):
    """
    normalize point cloud to unit bounding box

    The cloud is translated so that its bounding-box centre,
    center = (max + min) / 2, sits at the origin, then divided by
    scale = max(abs(x)) taken over all points and coordinates, so every
    coordinate of the result lies in [-1, 1].

    input: pc [N, P, dim] or [P, dim], as a numpy array or a torch tensor
    output: pc, centroid, furthest_distance
        pc                 normalized cloud, same shape as the input
        centroid           bounding-box centre, [N, 1, dim] or [1, dim]
        furthest_distance  scale, [N, 1, 1] or [1, 1]
      so that input == pc * furthest_distance + centroid (up to rounding).
    A cloud with zero extent (e.g. a single point) gets furthest_distance = 1
    instead of 0, so pc is all zeros rather than NaN.

    raises:
        TypeError: input is neither a numpy array nor a torch tensor.
        ValueError: input is not 2- or 3-dimensional.
    """
    if not isinstance(input, (np.ndarray, torch.Tensor)):
        raise TypeError(
            f"normalize_to_box expects a numpy array or torch tensor, got {type(input).__name__}")
    if input.ndim == 2:
        axis = 0
    elif input.ndim == 3:
        axis = 1
    else:
        raise ValueError(
            f"normalize_to_box expects a [P, dim] or [N, P, dim] point cloud, got shape {tuple(input.shape)}")

    if isinstance(input, np.ndarray):
        maxP = np.amax(input, axis=axis, keepdims=True)
        minP = np.amin(input, axis=axis, keepdims=True)
        centroid = (maxP + minP) / 2
        input = input - centroid
        # Reduce over the point axis and the coordinate axis together.
        furthest_distance = np.amax(np.abs(input), axis=(axis, -1), keepdims=True)
        # Avoid 0/0 for degenerate clouds; the round-trip identity still holds.
        furthest_distance = np.where(furthest_distance == 0, np.ones_like(furthest_distance), furthest_distance)
    else:
        maxP = torch.max(input, dim=axis, keepdim=True)[0]
        minP = torch.min(input, dim=axis, keepdim=True)[0]
        centroid = (maxP + minP) / 2
        input = input - centroid
        # Flatten points and coordinates into one axis so a single max covers both.
        flat_shape = list(input.shape[:axis]) + [-1]
        furthest_distance = torch.max(torch.abs(input).reshape(flat_shape), dim=axis, keepdim=True)[0]
        furthest_distance = furthest_distance.unsqueeze(-1)
        # Avoid 0/0 for degenerate clouds; the round-trip identity still holds.
        furthest_distance = torch.where(furthest_distance == 0, torch.ones_like(furthest_distance), furthest_distance)
    input = input / furthest_distance
    return input, centroid, furthest_distance
def normalize(tensor, dim=-1):
    """Scale ``tensor`` to unit L2 norm along ``dim``; eps=1e-12 guards zero vectors."""
    l2_normalize = torch.nn.functional.normalize
    return l2_normalize(tensor, dim=dim, p=2, eps=1e-12)
def check_values(tensor):
    """Return True when every element of ``tensor`` is finite (no NaN and no +/-Inf)."""
    all_finite = torch.isfinite(tensor).all()
    return bool(all_finite.item())
class ScatterAdd(torch.autograd.Function):
    """Differentiable ``scatter_add`` into a freshly allocated tensor.

    forward(src, idx, dim, out_size, fill=0.0):
        out = torch.full(out_size, fill); out.scatter_add_(dim, idx, src)
    backward:
        the gradient w.r.t. ``src`` is the incoming gradient gathered back
        along ``dim`` at ``idx``; ``idx``, ``dim``, ``out_size`` and ``fill``
        receive no gradient.

    NOTE(review): ``torch.gather`` returns a tensor shaped like ``idx``, so the
    backward pass assumes ``idx`` has the same shape as ``src``.
    """

    @staticmethod
    def forward(ctx, src, idx, dim, out_size, fill=0.0):
        out = torch.full(out_size, fill, device=src.device, dtype=src.dtype)
        out.scatter_add_(dim, idx, src)
        # Only the indices are needed to route gradients back to ``src``.
        ctx.save_for_backward(idx)
        ctx.dim = dim
        return out

    @staticmethod
    def backward(ctx, ograd):
        idx, = ctx.saved_tensors
        # out[idx[i]] += src[i]  =>  d(out)/d(src[i]) picks ograd at idx[i].
        grad_src = torch.gather(ograd, ctx.dim, idx)
        return grad_src, None, None, None, None


_scatter_add = ScatterAdd.apply
def scatter_add(src, idx, dim, out_size=None, fill=0.0):
    """Differentiably scatter-add ``src`` into a new tensor along ``dim``.

    When ``out_size`` is omitted, the output takes the shape of ``src`` except
    along ``dim``, where it is ``idx.max() + 1``, just large enough for every
    index. Untouched entries are initialized to ``fill``.
    """
    if out_size is not None:
        return _scatter_add(src, idx, dim, out_size, fill)
    inferred_size = list(src.size())
    inferred_size[dim] = idx.max().item() + 1
    return _scatter_add(src, idx, dim, inferred_size, fill)
def mean_value_coordinates_3D(query, vertices, faces, verbose=False):
    """
    Mean value coordinates (MVC) for closed 3D triangle meshes, following
    Tao Ju et al., "Mean Value Coordinates for Closed Triangular Meshes".

    params:
        query (B,P,3): points at which the coordinates are evaluated
        vertices (B,N,3): mesh (cage) vertices
        faces (B,F,3): int64 vertex indices of each triangle (used as a
            torch.gather index)
        verbose: if True, also return the per-face weights ``wi``
    return:
        wj (B,P,N): per-vertex weights of every query point, divided by their
            sum over N (a row whose raw weights sum to exactly 0 is divided by
            1 instead)
        when verbose: (wj, wi), where wi (B,P,F,3) holds the un-normalized
            weight of each triangle corner before it is scattered onto the
            vertices
    special cases:
        - query within 1e-8 of a vertex: that vertex gets weight 1, all
          others 0
        - query on a triangle (pi - h < 1e-4): only that triangle contributes,
          with weights sin(theta_i) * d_{i-1} * d_{i+1}
        - query coplanar with a triangle but outside it (any |si| <= 1e-5):
          that triangle contributes 0
    raises:
        AssertionError: if NaN/Inf appear in theta_i, si or di (only while
            Python assertions are enabled)
    """
    B, F, _ = faces.shape
    _, P, _ = query.shape
    _, N, _ = vertices.shape
    # u_i = p_i - x (B,P,N,3): vector from each query point to each vertex
    uj = vertices.unsqueeze(1) - query.unsqueeze(2)
    # \|u_i\| (B,P,N,1): distance from each query point to each vertex
    dj = torch.norm(uj, dim=-1, p=2, keepdim=True)
    # keep only the direction (unit vectors)
    uj = normalize(uj, dim=-1)
    # gather the three corner directions of every triangle: B,P,F,3,3
    ui = torch.gather(uj.unsqueeze(2).expand(-1,-1,F,-1,-1),
                      3,
                      faces.unsqueeze(1).unsqueeze(-1).expand(-1,P,-1,-1,3))
    # li = \|u_{i+1}-u_{i-1}\| (B,P,F,3): chord between two unit vectors, in [0, 2]
    li = torch.norm(ui[:,:,:,[1, 2, 0],:] - ui[:, :, :,[2, 0, 1],:], dim=-1, p=2)
    eps = 2e-5
    # Straight-through clamp: the value becomes 2-eps but the gradient w.r.t.
    # li is unchanged, keeping asin(li/2) below its singularity at 1.
    li = torch.where(li>=2, li-(li.detach()-(2-eps)), li)
    # NOTE(review): li is a norm (>= 0), so this branch can never trigger.
    li = torch.where(li<=-2, li-(li.detach()+(2-eps)), li)
    # asin(x) has an infinite derivative at +/-1, hence the clamp above
    # θi = 2arcsin[li/2] (B,P,F,3): angle between u_{i+1} and u_{i-1}
    theta_i = 2*torch.asin(li/2)
    assert(check_values(theta_i))
    # h = (θ_0 + θ_1 + θ_2) / 2: B,P,F,1
    h = torch.sum(theta_i, dim=-1, keepdim=True)/2
    # (B,P,F,3) ci ← (2sin[h]sin[h−θi])/(sin[θ_{i+1}]sin[θ_{i−1}])−1
    ci = 2*torch.sin(h)*torch.sin(h-theta_i)/(torch.sin(theta_i[:,:,:,[1, 2, 0]])*torch.sin(theta_i[:,:,:,[2, 0, 1]]))-1
    # NOTE: because of floating point ci can be slightly larger than 1, causing problem with sqrt(1-ci^2)
    # NOTE: the derivative of sqrt(x) is infinite at x=0, hence clamp to 1-eps
    # (same straight-through trick as for li above)
    eps = 1e-5
    ci = torch.where(ci>=1, ci-(ci.detach()-(1-eps)), ci)
    ci = torch.where(ci<=-1, ci-(ci.detach()+(1-eps)), ci)
    # si ← sign[det[u1,u2,u3]]sqrt(1-ci^2)
    # (B,P,F)*(B,P,F,3)
    si = torch.sign(torch.det(ui)).unsqueeze(-1)*torch.sqrt(1-ci**2) # sqrt gradient is not finite at 0 (see clamp above)
    assert(check_values(si))
    # di: distance from the query to each triangle corner (B,P,F,3)
    di = torch.gather(dj.unsqueeze(2).squeeze(-1).expand(-1,-1,F,-1), 3,
                      faces.unsqueeze(1).expand(-1,P,-1,-1))
    assert(check_values(di))
    # wi ← (θi − c[i+1]θ[i−1] − c[i−1]θ[i+1]) / (d_i sin[θ_{i+1}] s[i−1])
    # B,P,F,3
    # CHECK is there a 2* in the denominator
    # (a constant factor applied to every wi here would cancel in
    # wj_normalised below, though not in the wi returned when verbose)
    wi = (theta_i-ci[:,:,:,[1,2,0]]*theta_i[:,:,:,[2,0,1]]-ci[:,:,:,[2,0,1]]*theta_i[:,:,:,[1,2,0]])/(di*torch.sin(theta_i[:,:,:,[1,2,0]])*si[:,:,:,[2,0,1]])
    # if ∃i, |si| ≤ ε, set wi to 0: the query is coplanar with the triangle
    # but outside it, so the triangle is ignored.
    # NOTE: an unused alternative criterion would threshold the squared
    # distance from the query to the triangle's plane instead of |si|.
    wi = torch.where(torch.any(torch.abs(si) <= 1e-5, keepdim=True, dim=-1), torch.zeros_like(wi), wi)
    # if π − h < ε, the query lies on the triangle: use 2D barycentric-style
    # weights for that triangle only. inside_triangle: (B,P,F) bool
    inside_triangle = (PI-h).squeeze(-1)<1e-4
    # for query points lying on any triangle, zero the weights of all F faces
    wi = torch.where(torch.any(inside_triangle, dim=-1, keepdim=True).unsqueeze(-1), torch.zeros_like(wi), wi)
    # ...then give the triangle(s) containing the point wi ← sin[θi] d{i−1} d{i+1}
    # CHECK is it di https://www.cse.wustl.edu/~taoju/research/meanvalue.pdf or li http://citeseerx.ist.psu.edu/viewdoc/download?doi=10.1.1.516.1856&rep=rep1&type=pdf
    wi = torch.where(inside_triangle.unsqueeze(-1).expand(-1,-1,-1,wi.shape[-1]), torch.sin(theta_i)*di[:,:,:,[2,0,1]]*di[:,:,:,[1,2,0]], wi)
    # sum the corner weights of all faces onto their vertices: (B,P,F*3) -> (B,P,N)
    wj = scatter_add(wi.reshape(B,P,-1).contiguous(), faces.unsqueeze(1).expand(-1,P,-1,-1).reshape(B,P,-1), 2, out_size=(B,P,N))
    # query (almost) coincides with a vertex: (B,P,N) bool
    close_to_point = dj.squeeze(-1) < 1e-8
    # for such query points, zero the weights of all N vertices...
    wj = torch.where(torch.any(close_to_point, dim=-1, keepdim=True), torch.zeros_like(wj), wj)
    # ...and give weight 1 to the coinciding vertex (or vertices)
    wj = torch.where(close_to_point, torch.ones_like(wj), wj)
    # (B,P,1) normalizer; an all-zero sum is replaced by 1 to avoid 0/0
    sumWj = torch.sum(wj, dim=-1, keepdim=True)
    sumWj = torch.where(sumWj==0, torch.ones_like(sumWj), sumWj)
    wj_normalised = wj / sumWj
    if verbose:
        # NOTE: the second value is the per-face wi (B,P,F,3), not the
        # un-normalized per-vertex wj.
        return wj_normalised, wi
    else:
        return wj_normalised