Spaces:
Running
on
L40S
Running
on
L40S
from typing import Optional, Tuple | |
import numpy as np | |
import torch | |
import torch.nn as nn | |
from jaxtyping import Float, Integer | |
from torch import Tensor | |
from .mesh import Mesh | |
class IsosurfaceHelper(nn.Module):
    """Base class for modules that extract an isosurface mesh from a grid.

    Subclasses expose their grid vertex positions through the
    ``grid_vertices`` property. ``grid_vertices`` and
    ``requires_instance_per_batch`` are properties so callers read them as
    attributes (e.g. ``helper.grid_vertices + offsets``), matching how
    ``MarchingTetrahedraHelper.forward`` consumes ``grid_vertices``.
    """

    # (min, max) coordinate range of the grid. MarchingTetrahedraHelper uses
    # its width to scale per-vertex grid deformations.
    points_range: Tuple[float, float] = (0, 1)

    @property
    def grid_vertices(self) -> Float[Tensor, "N 3"]:
        """Positions of the grid vertices; must be provided by subclasses.

        Raises:
            NotImplementedError: always, in this base class.
        """
        raise NotImplementedError

    @property
    def requires_instance_per_batch(self) -> bool:
        """Flag queried by callers; the base implementation returns ``False``.

        NOTE(review): presumably signals whether one helper instance is needed
        per batch element — confirm against the call sites.
        """
        return False
class MarchingTetrahedraHelper(IsosurfaceHelper):
    """Extract a triangle mesh from per-vertex levels on a tetrahedral grid.

    The grid (vertex positions and per-tetrahedron vertex indices) is loaded
    from ``tets_path``. ``forward`` runs marching tetrahedra, optionally on a
    deformed copy of the grid, and returns a ``Mesh``.

    Args:
        resolution: Grid resolution; deformations are scaled so each vertex
            moves by at most ``(points_range width) / resolution`` per axis.
        tets_path: Path to an ``.npz`` archive (read with ``np.load``)
            containing ``vertices`` and ``indices`` arrays.
    """

    def __init__(self, resolution: int, tets_path: str):
        super().__init__()
        self.resolution = resolution
        self.tets_path = tets_path

        # Lookup table indexed by a tet's 4-bit occupancy mask (see _forward).
        # Each row holds up to two triangles as triples of local edge indices
        # (0-5, positions in ``base_tet_edges``); -1 marks unused slots.
        self.triangle_table: Integer[Tensor, "..."]
        self.register_buffer(
            "triangle_table",
            torch.as_tensor(
                [
                    [-1, -1, -1, -1, -1, -1],
                    [1, 0, 2, -1, -1, -1],
                    [4, 0, 3, -1, -1, -1],
                    [1, 4, 2, 1, 3, 4],
                    [3, 1, 5, -1, -1, -1],
                    [2, 3, 0, 2, 5, 3],
                    [1, 4, 0, 1, 5, 4],
                    [4, 2, 5, -1, -1, -1],
                    [4, 5, 2, -1, -1, -1],
                    [4, 1, 0, 4, 5, 1],
                    [3, 2, 0, 3, 5, 2],
                    [1, 3, 5, -1, -1, -1],
                    [4, 1, 2, 4, 3, 1],
                    [3, 0, 4, -1, -1, -1],
                    [2, 0, 1, -1, -1, -1],
                    [-1, -1, -1, -1, -1, -1],
                ],
                dtype=torch.long,
            ),
            persistent=False,
        )
        # Number of triangles emitted for each occupancy mask.
        self.num_triangles_table: Integer[Tensor, "..."]
        self.register_buffer(
            "num_triangles_table",
            torch.as_tensor(
                [0, 1, 1, 2, 1, 2, 2, 1, 1, 2, 2, 1, 2, 1, 1, 0], dtype=torch.long
            ),
            persistent=False,
        )
        # The six edges of a tetrahedron as flattened local-vertex pairs:
        # (0,1), (0,2), (0,3), (1,2), (1,3), (2,3).
        self.base_tet_edges: Integer[Tensor, "..."]
        self.register_buffer(
            "base_tet_edges",
            torch.as_tensor([0, 1, 0, 2, 0, 3, 1, 2, 1, 3, 2, 3], dtype=torch.long),
            persistent=False,
        )

        # Use the archive as a context manager so its file handle is closed
        # once both arrays have been read into memory.
        with np.load(self.tets_path) as tets:
            grid_vertices = torch.from_numpy(tets["vertices"]).float()
            indices = torch.from_numpy(tets["indices"]).long()

        self._grid_vertices: Float[Tensor, "..."]
        self.register_buffer("_grid_vertices", grid_vertices, persistent=False)
        self.indices: Integer[Tensor, "..."]
        self.register_buffer("indices", indices, persistent=False)

        # Lazily computed and cached by the ``all_edges`` property.
        self._all_edges: Optional[Integer[Tensor, "Ne 2"]] = None

        center_indices, boundary_indices = self.get_center_boundary_index(
            self._grid_vertices
        )
        self.center_indices: Integer[Tensor, "..."]
        self.register_buffer("center_indices", center_indices, persistent=False)
        self.boundary_indices: Integer[Tensor, "..."]
        self.register_buffer("boundary_indices", boundary_indices, persistent=False)

    def get_center_boundary_index(self, verts):
        """Find the vertex closest to the origin and the boundary vertices.

        Args:
            verts: Vertex positions with coordinates on the last dimension.

        Returns:
            ``(center_idx, boundary_idx)``: the index of the vertex with the
            smallest squared norm, and a 1-D tensor with the indices of all
            vertices having at least one coordinate equal to the global
            minimum or maximum of ``verts``.
        """
        sq_norm = torch.sum(verts**2, dim=-1)
        center_idx = torch.argmin(sq_norm)
        # The comparison is against the global min/max over every coordinate,
        # which presumes an axis-aligned cubic grid (NOTE(review): confirm).
        at_max = verts == verts.max()
        at_min = verts == verts.min()
        on_boundary = torch.bitwise_or(at_min, at_max)
        on_boundary = torch.sum(on_boundary.float(), dim=-1)
        boundary_idx = torch.nonzero(on_boundary)
        return center_idx, boundary_idx.squeeze(dim=-1)

    def normalize_grid_deformation(
        self, grid_vertex_offsets: Float[Tensor, "Nv 3"]
    ) -> Float[Tensor, "Nv 3"]:
        """Squash raw offsets with tanh and scale them to about half a tet."""
        return (
            (self.points_range[1] - self.points_range[0])
            / self.resolution  # half tet size is approximately 1 / self.resolution
            * torch.tanh(grid_vertex_offsets)
        )  # FIXME: hard-coded activation

    @property
    def grid_vertices(self) -> Float[Tensor, "Nv 3"]:
        """Undeformed grid vertex positions loaded from ``tets_path``."""
        return self._grid_vertices

    @property
    def all_edges(self) -> Integer[Tensor, "Ne 2"]:
        """Unique edges of the tetrahedral grid as sorted vertex-index pairs.

        The result is cached. It is recomputed if ``indices`` has since moved
        to another device (e.g. after ``module.to(...)``), because the cache
        is a plain attribute rather than a buffer and is not moved along.
        """
        if self._all_edges is None or self._all_edges.device != self.indices.device:
            # Compute on the device of ``indices`` (GPU when available):
            # torch.unique over every tet edge is very slow on CPU.
            edges = self.indices[:, self.base_tet_edges].reshape(-1, 2)
            edges_sorted = torch.sort(edges, dim=1)[0]
            self._all_edges = torch.unique(edges_sorted, dim=0)
        return self._all_edges

    def sort_edges(self, edges_ex2):
        """Reorder each vertex pair so the smaller index comes first.

        Args:
            edges_ex2: ``(E, 2)`` integer tensor of vertex-index pairs.

        Returns:
            ``(E, 1, 2)`` tensor of reordered pairs (``gather`` keeps a size-1
            column, and the two columns are stacked on a new last dimension).
        """
        with torch.no_grad():
            order = (edges_ex2[:, 0] > edges_ex2[:, 1]).long()
            order = order.unsqueeze(dim=1)
            a = torch.gather(input=edges_ex2, index=order, dim=1)
            b = torch.gather(input=edges_ex2, index=1 - order, dim=1)
        return torch.stack([a, b], -1)

    def _forward(self, pos_nx3, sdf_n, tet_fx4):
        """Run marching tetrahedra on a (possibly deformed) tetrahedral grid.

        Args:
            pos_nx3: ``(N, 3)`` vertex positions.
            sdf_n: Per-vertex level, ``(N,)`` or ``(N, 1)``; a vertex counts
                as occupied where the level is > 0.
            tet_fx4: ``(F, 4)`` vertex indices of each tetrahedron.

        Returns:
            ``(verts, faces)``: ``(V, 3)`` vertices interpolated at the sign
            change of each crossed edge, and ``(T, 3)`` triangle indices into
            ``verts``.
        """
        # Topology (which edges are crossed, and how edges map to output
        # vertices) is discrete, so it is computed without autograd.
        with torch.no_grad():
            occ_n = sdf_n > 0
            occ_fx4 = occ_n[tet_fx4.reshape(-1)].reshape(-1, 4)
            occ_sum = torch.sum(occ_fx4, -1)
            # Only tets with some, but not all, vertices occupied hold surface.
            valid_tets = (occ_sum > 0) & (occ_sum < 4)

            # Find all edges of the valid tets, deduplicated.
            all_edges = tet_fx4[valid_tets][:, self.base_tet_edges].reshape(-1, 2)
            all_edges = self.sort_edges(all_edges)
            unique_edges, idx_map = torch.unique(all_edges, dim=0, return_inverse=True)
            unique_edges = unique_edges.long()

            # Each edge whose endpoints differ in occupancy gets one output
            # vertex id; every other edge maps to -1.
            mask_edges = occ_n[unique_edges.reshape(-1)].reshape(-1, 2).sum(-1) == 1
            mapping = torch.full(
                (unique_edges.shape[0],), -1, dtype=torch.long, device=pos_nx3.device
            )
            mapping[mask_edges] = torch.arange(
                mask_edges.sum(), dtype=torch.long, device=pos_nx3.device
            )
            idx_map = mapping[idx_map]  # map edges to verts

            interp_v = unique_edges[mask_edges]

        # Differentiable part: place a vertex at the zero crossing of the
        # linear interpolant on each crossed edge, so gradients reach both
        # ``pos_nx3`` and ``sdf_n``.
        edges_to_interp = pos_nx3[interp_v.reshape(-1)].reshape(-1, 2, 3)
        edges_to_interp_sdf = sdf_n[interp_v.reshape(-1)].reshape(-1, 2, 1)
        edges_to_interp_sdf[:, -1] *= -1  # (s0, s1) -> (s0, -s1)
        # s0 - s1: never zero, since exactly one endpoint has level > 0.
        denominator = edges_to_interp_sdf.sum(1, keepdim=True)
        # Weights (-s1, s0) / (s0 - s1) for endpoints (p0, p1).
        edges_to_interp_sdf = torch.flip(edges_to_interp_sdf, [1]) / denominator
        verts = (edges_to_interp * edges_to_interp_sdf).sum(1)

        # One row of 6 output-vertex ids (or -1) per valid tet.
        idx_map = idx_map.reshape(-1, 6)

        # 4-bit occupancy mask per valid tet (bit i set when vertex i is
        # occupied), used to index the lookup tables.
        v_id = torch.pow(2, torch.arange(4, dtype=torch.long, device=pos_nx3.device))
        tetindex = (occ_fx4[valid_tets] * v_id.unsqueeze(0)).sum(-1)
        num_triangles = self.num_triangles_table[tetindex]

        # Generate triangle indices for tets emitting one and two triangles.
        one_tri = num_triangles == 1
        two_tri = num_triangles == 2
        faces = torch.cat(
            (
                torch.gather(
                    input=idx_map[one_tri],
                    dim=1,
                    index=self.triangle_table[tetindex[one_tri]][:, :3],
                ).reshape(-1, 3),
                torch.gather(
                    input=idx_map[two_tri],
                    dim=1,
                    index=self.triangle_table[tetindex[two_tri]][:, :6],
                ).reshape(-1, 3),
            ),
            dim=0,
        )
        return verts, faces

    def forward(
        self,
        level: Float[Tensor, "N3 1"],
        deformation: Optional[Float[Tensor, "N3 3"]] = None,
    ) -> Mesh:
        """Extract the surface where ``level`` changes sign as a ``Mesh``.

        Args:
            level: Per-grid-vertex level; vertices with level > 0 are occupied.
            deformation: Optional raw per-vertex offsets, passed through
                ``normalize_grid_deformation`` and added to the grid vertices.

        Returns:
            Mesh with ``v_pos``/``t_pos_idx`` from marching tetrahedra, plus
            extras: the (deformed) grid vertices, the grid edges, ``level``
            and ``deformation``.
        """
        grid_vertices = self.grid_vertices
        if deformation is not None:
            grid_vertices = grid_vertices + self.normalize_grid_deformation(
                deformation
            )
        v_pos, t_pos_idx = self._forward(grid_vertices, level, self.indices)
        return Mesh(
            v_pos=v_pos,
            t_pos_idx=t_pos_idx,
            # extras
            grid_vertices=grid_vertices,
            tet_edges=self.all_edges,
            grid_level=level,
            grid_deformation=deformation,
        )