# Spaces: Running on L40S
from __future__ import annotations

from typing import Any, Dict, Optional

import torch
import torch.nn.functional as F
from jaxtyping import Float, Integer
from torch import Tensor

from sf3d.box_uv_unwrap import box_projection_uv_unwrap
from sf3d.models.utils import dot
class Mesh:
    """Triangle mesh with lazily computed, cached per-vertex attributes.

    The mesh is described by vertex positions (``v_pos``) and per-face vertex
    indices (``t_pos_idx``). Vertex normals, tangents, texture coordinates and
    the unique edge list are derived on first access and cached.

    Note that accessing ``v_tex`` or ``v_tng`` on a mesh without UVs calls
    :meth:`unwrap_uv`, which *replaces* ``v_pos`` and ``t_pos_idx`` with a
    re-indexed version in which every face owns three private vertices.
    """

    def __init__(
        self, v_pos: Float[Tensor, "Nv 3"], t_pos_idx: Integer[Tensor, "Nf 3"], **kwargs
    ) -> None:
        """Store the geometry and register any keyword arguments as extras.

        Args:
            v_pos: Vertex positions.
            t_pos_idx: For every face, the indices of its three vertices.
            **kwargs: Arbitrary additional data, stored in ``self.extras``.
        """
        self.v_pos: Float[Tensor, "Nv 3"] = v_pos
        self.t_pos_idx: Integer[Tensor, "Nf 3"] = t_pos_idx
        # Lazily populated caches (None means "not computed yet").
        self._v_nrm: Optional[Float[Tensor, "Nv 3"]] = None
        self._v_tng: Optional[Float[Tensor, "Nv 3"]] = None
        # UVs are stored as two components per vertex (see unwrap_uv).
        self._v_tex: Optional[Float[Tensor, "Nt 2"]] = None
        self._edges: Optional[Integer[Tensor, "Ne 2"]] = None
        self.extras: Dict[str, Any] = {}
        for k, v in kwargs.items():
            self.add_extra(k, v)

    def add_extra(self, k, v) -> None:
        """Attach (or overwrite) the auxiliary value ``v`` under key ``k``."""
        self.extras[k] = v

    @property
    def requires_grad(self):
        """Whether the vertex position tensor requires gradients."""
        return self.v_pos.requires_grad

    @property
    def v_nrm(self):
        """Unit vertex normals, computed on first access and cached."""
        if self._v_nrm is None:
            self._v_nrm = self._compute_vertex_normal()
        return self._v_nrm

    @property
    def v_tng(self):
        """Unit vertex tangents, computed on first access and cached."""
        if self._v_tng is None:
            if self._v_tex is None:
                # Tangents need UVs. unwrap_uv re-indexes the mesh and caches
                # tangents for the new indexing itself, so computing them
                # again here would only duplicate the work.
                self.unwrap_uv()
            else:
                self._v_tng = self._compute_vertex_tangent()
        return self._v_tng

    @property
    def v_tex(self):
        """Per-vertex UVs; unwraps the mesh on first access if none exist."""
        if self._v_tex is None:
            self.unwrap_uv()
        return self._v_tex

    @property
    def edges(self):
        """Unique undirected edges as index pairs, computed once and cached."""
        if self._edges is None:
            self._edges = self._compute_edges()
        return self._edges

    @staticmethod
    def _clamp_away_from_zero(x: Tensor, eps: float = 1e-6) -> Tensor:
        """Return ``x`` with magnitudes below ``eps`` pushed out to ``eps``.

        The sign of every entry is preserved; exact zeros map to ``+eps``.
        """
        return torch.where(x >= 0, x.clamp(min=eps), x.clamp(max=-eps))

    def _compute_vertex_normal(self):
        """Compute unit vertex normals by accumulating face normals.

        Face normals are the unnormalized cross products of two triangle
        edges, so larger triangles contribute more to their vertices.
        """
        corner_idx = [self.t_pos_idx[:, k] for k in range(3)]
        v0, v1, v2 = (self.v_pos[idx, :] for idx in corner_idx)

        face_normals = torch.cross(v1 - v0, v2 - v0, dim=-1)

        # Splat face normals to vertices
        v_nrm = torch.zeros_like(self.v_pos)
        for idx in corner_idx:
            v_nrm.scatter_add_(0, idx[:, None].repeat(1, 3), face_normals)

        # Normalize, replace zero (degenerated) normals with some default value
        fallback = torch.as_tensor([0.0, 0.0, 1.0]).to(v_nrm)
        v_nrm = torch.where(dot(v_nrm, v_nrm) > 1e-20, v_nrm, fallback)
        v_nrm = F.normalize(v_nrm, dim=1)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(v_nrm))

        return v_nrm

    def _compute_vertex_tangent(self):
        """Compute unit vertex tangents from positions and UVs.

        A tangent is derived per triangle from its position and UV deltas,
        averaged onto the triangle's vertices, and finally made perpendicular
        to the vertex normal.
        """
        # Fetch UVs first: this may trigger unwrap_uv, which re-indexes the
        # mesh. Everything below must be gathered from the final indexing.
        v_tex = self.v_tex
        v_nrm = self.v_nrm

        # t_nrm_idx is always the same as t_pos_idx
        corner_idx = [self.t_pos_idx[:, k] for k in range(3)]
        pos = [self.v_pos[idx] for idx in corner_idx]
        tex = [v_tex[idx] for idx in corner_idx]

        tangents = torch.zeros_like(v_nrm)
        tansum = torch.zeros_like(v_nrm)

        # Compute tangent space for each triangle
        duv1 = tex[1] - tex[0]
        duv2 = tex[2] - tex[0]
        dpos1 = pos[1] - pos[0]
        dpos2 = pos[2] - pos[0]

        tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
        denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]

        # Avoid division by zero for degenerated texture coordinates. The sign
        # of the determinant must be kept: it is negative for mirrored UV
        # triangles, and clamping those to a positive epsilon would flip (and
        # hugely over-weight) their tangent.
        denom_safe = self._clamp_away_from_zero(denom, 1e-6)
        tang = tng_nom / denom_safe

        # Update all 3 vertices
        for corner in corner_idx:
            idx = corner[:, None].repeat(1, 3)
            tangents.scatter_add_(0, idx, tang)  # tangents[n_i] += tang
            tansum.scatter_add_(0, idx, torch.ones_like(tang))  # tansum[n_i] += 1

        # Also normalize it. Here we do not normalize the individual triangles
        # first so larger area triangles influence the tangent space more.
        # Vertices referenced by no face have a count of 0; clamp to 1 so they
        # end up with a zero tangent instead of NaN.
        tangents = tangents / tansum.clamp(min=1)

        # Normalize and make sure tangent is perpendicular to normal
        tangents = F.normalize(tangents, dim=1)
        tangents = F.normalize(tangents - dot(tangents, v_nrm) * v_nrm)

        if torch.is_anomaly_enabled():
            assert torch.all(torch.isfinite(tangents))

        return tangents

    def unwrap_uv(
        self,
        island_padding: float = 0.02,
    ) -> Mesh:
        """Generate UVs by box projection and re-index the mesh to match.

        This mutates the mesh: ``v_pos`` and ``t_pos_idx`` are replaced so
        that every face references three vertices of its own, and the cached
        normals, tangents and edges are refreshed or invalidated accordingly.

        Args:
            island_padding: Forwarded to ``box_projection_uv_unwrap``
                (presumably the spacing between UV islands — confirm there).

        Returns:
            This mesh, to allow chaining.
        """
        uv, indices = box_projection_uv_unwrap(
            self.v_pos, self.v_nrm, self.t_pos_idx, island_padding
        )

        # Do store per vertex UVs.
        # This means we need to duplicate some vertices at the seams
        individual_vertices = self.v_pos[self.t_pos_idx].reshape(-1, 3)
        individual_faces = torch.arange(
            individual_vertices.shape[0],
            device=individual_vertices.device,
            dtype=self.t_pos_idx.dtype,
        ).reshape(-1, 3)
        # NOTE: the V coordinate is stored as returned (no `1 - v` flip).
        uv_flat = uv[indices].reshape((-1, 2))

        self.v_pos = individual_vertices
        self.t_pos_idx = individual_faces
        self._v_tex = uv_flat

        # Every cache derived from the old indexing is now stale.
        self._edges = None
        self._v_nrm = self._compute_vertex_normal()
        self._v_tng = self._compute_vertex_tangent()

        return self

    def _compute_edges(self):
        """Return the unique undirected edges of all faces as index pairs."""
        # Compute edges
        edges = torch.cat(
            [
                self.t_pos_idx[:, [0, 1]],
                self.t_pos_idx[:, [1, 2]],
                self.t_pos_idx[:, [2, 0]],
            ],
            dim=0,
        )
        # Order each pair (low, high) so both directions of an edge coincide.
        edges = edges.sort()[0]
        edges = torch.unique(edges, dim=0)
        return edges