# Spaces: Running on L4
import math | |
from typing import Tuple | |
import torch | |
import torch.nn.functional as F | |
from jaxtyping import Bool, Float, Integer, Int, Num | |
from torch import Tensor | |
def tri_winding(tri: Float[Tensor, "*B 3 2"]) -> Float[Tensor, "*B 3 3"]:
    """Lift 2D triangles to homogeneous 3x3 vertex matrices with positive winding.

    Every vertex row ``(x, y)`` becomes ``(x, y, 1)``. The determinant of the
    resulting matrix is twice the signed triangle area; triangles with a
    negative determinant get their last two vertices swapped, so every returned
    matrix has a non-negative determinant (counter-clockwise in a y-up frame).

    Args:
        tri: Triangles with the three vertices along dim -2 and xy along dim -1.

    Returns:
        Homogeneous vertex matrices of shape ``(*B, 3, 3)``.
    """
    homogeneous = F.pad(tri, (0, 1), "constant", 1.0)
    is_clockwise = torch.det(homogeneous) < 0
    # Reordering the vertices (0, 1, 2) -> (0, 2, 1) reverses the winding.
    reordered = homogeneous[..., [0, 2, 1], :]
    homogeneous[is_clockwise] = reordered[is_clockwise]
    return homogeneous
def triangle_intersection_2d(
    t1: Float[Tensor, "*B 3 2"],
    t2: Float[Tensor, "*B 3 2"],
    eps=1e-12,
) -> Bool[Tensor, "*B"]:  # noqa: F821
    """Return True where the paired 2D triangles overlap, False otherwise.

    Uses the separating-edge test: two triangles are disjoint if, for some edge
    of either triangle, all three vertices of the other triangle lie outside of
    (or on) that edge.

    Args:
        t1: First triangles, shape ``(*B, 3, 2)``.
        t2: Second triangles, same shape as ``t1``; ``t1[b]`` is tested against
            ``t2[b]``.
        eps: Threshold on the point/edge determinant (twice the signed area of
            the edge and the point). A determinant ``<= eps`` counts as
            "outside", so points on an edge count as outside and triangles that
            merely touch are generally not reported as colliding. ``None``
            uses a threshold of 0.

    Returns:
        Boolean tensor of shape ``*B``.
    """
    batch_shape = t1.shape[:-2]
    # The edge construction below indexes a single batch dim. Flatten any number
    # of leading dims (including none) and restore them at the end. math.prod
    # is used instead of -1 so empty batches reshape cleanly.
    n = math.prod(batch_shape)
    t1 = t1.reshape(n, 3, 2)
    t2 = t2.reshape(n, 3, 2)

    def chk_edge(x: Float[Tensor, "N 3 3"]) -> Bool[Tensor, "N"]:  # noqa: F821
        """True where the point in row 2 is not strictly inside the edge rows 0 -> 1."""
        # logdet is NaN for negative and -inf for zero determinants, so both
        # count as "outside"; double precision limits cancellation errors.
        logdetx = torch.logdet(x.double())
        if eps is None:
            return ~torch.isfinite(logdetx)
        return ~(torch.isfinite(logdetx) & (logdetx > math.log(eps)))

    def separated_by_edges_of(a, b):
        """True where some edge of ``a`` has all three vertices of ``b`` outside it."""
        separated = torch.zeros(a.shape[0], dtype=torch.bool, device=a.device)
        for i in range(3):
            # Rolling by i yields the consistently oriented edges (0,1), (2,0), (1,2).
            edge = torch.roll(a, i, dims=1)[:, :2, :]
            all_outside = (
                chk_edge(torch.cat((edge, b[:, 0:1]), 1))
                & chk_edge(torch.cat((edge, b[:, 1:2]), 1))
                & chk_edge(torch.cat((edge, b[:, 2:3]), 1))
            )
            separated = separated | all_outside
        return separated

    # Enforce positive winding so "inside an edge" always means a positive determinant.
    t1s = tri_winding(t1)
    t2s = tri_winding(t2)
    separated = separated_by_edges_of(t1s, t2s) | separated_by_edges_of(t2s, t1s)
    return (~separated).reshape(batch_shape)
def dot(x, y, dim=-1):
    """Inner product of ``x`` and ``y`` along ``dim``, keeping that dim with size 1."""
    product = x * y
    return product.sum(dim=dim, keepdim=True)
def compute_vertex_normal(v_pos, t_pos_idx):
    """Return unit per-vertex normals of a triangle mesh.

    Each face adds its unnormalised cross-product normal (length = twice the
    face area) to its three vertices, so larger faces weigh more. Vertices whose
    accumulated normal is (near) zero, e.g. vertices used by no face, fall back
    to ``(0, 0, 1)``.

    Args:
        v_pos: Vertex positions, shape ``(Nv, 3)``.
        t_pos_idx: Integer vertex indices per face, shape ``(Nf, 3)``.

    Returns:
        Unit normals of shape ``(Nv, 3)``.
    """
    corner_ids = [t_pos_idx[:, k] for k in range(3)]
    p0, p1, p2 = (v_pos[ids, :] for ids in corner_ids)
    face_normals = torch.cross(p1 - p0, p2 - p0, dim=-1)

    # Splat every face normal onto its three corner vertices.
    accumulated = torch.zeros_like(v_pos)
    for ids in corner_ids:
        accumulated.scatter_add_(0, ids[:, None].repeat(1, 3), face_normals)

    # Squared length per vertex; keepdim so it broadcasts against xyz.
    squared_len = torch.sum(accumulated * accumulated, -1, keepdim=True)
    default_normal = torch.as_tensor([0.0, 0.0, 1.0]).to(accumulated)
    accumulated = torch.where(squared_len > 1e-20, accumulated, default_normal)

    normals = F.normalize(accumulated, dim=1)
    if torch.is_anomaly_enabled():
        assert torch.all(torch.isfinite(normals))
    return normals
def _box_assign_vertex_to_cube_face( | |
vertex_positions: Float[Tensor, "Nv 3"], | |
vertex_normals: Float[Tensor, "Nv 3"], | |
triangle_idxs: Integer[Tensor, "Nf 3"], | |
bbox: Float[Tensor, "2 3"], | |
) -> Tuple[Float[Tensor, "Nf 3 2"], Integer[Tensor, "Nf 3"]]: | |
# Test to not have a scaled model to fit the space better | |
# bbox_min = bbox[:1].mean(-1, keepdim=True) | |
# bbox_max = bbox[1:].mean(-1, keepdim=True) | |
# v_pos_normalized = (vertex_positions - bbox_min) / (bbox_max - bbox_min) | |
# Create a [0, 1] normalized vertex position | |
v_pos_normalized = (vertex_positions - bbox[:1]) / (bbox[1:] - bbox[:1]) | |
# And to [-1, 1] | |
v_pos_normalized = 2.0 * v_pos_normalized - 1.0 | |
# Get all vertex positions for each triangle | |
# Now how do we define to which face the triangle belongs? Mean face pos? Max vertex pos? | |
v0 = v_pos_normalized[triangle_idxs[:, 0]] | |
v1 = v_pos_normalized[triangle_idxs[:, 1]] | |
v2 = v_pos_normalized[triangle_idxs[:, 2]] | |
tri_stack = torch.stack([v0, v1, v2], dim=1) | |
vn0 = vertex_normals[triangle_idxs[:, 0]] | |
vn1 = vertex_normals[triangle_idxs[:, 1]] | |
vn2 = vertex_normals[triangle_idxs[:, 2]] | |
tri_stack_nrm = torch.stack([vn0, vn1, vn2], dim=1) | |
# Just average the normals per face | |
face_normal = F.normalize(torch.sum(tri_stack_nrm, 1), eps=1e-6, dim=-1) | |
# Now decide based on the face normal in which box map we project | |
# abs_x, abs_y, abs_z = tri_stack_nrm.abs().unbind(-1) | |
abs_x, abs_y, abs_z = tri_stack.abs().unbind(-1) | |
axis = torch.tensor( | |
[ | |
[1, 0, 0], # 0 | |
[-1, 0, 0], # 1 | |
[0, 1, 0], # 2 | |
[0, -1, 0], # 3 | |
[0, 0, 1], # 4 | |
[0, 0, -1], # 5 | |
], | |
device=face_normal.device, | |
dtype=face_normal.dtype, | |
) | |
face_normal_axis = (face_normal[:, None] * axis[None]).sum(-1) | |
index = face_normal_axis.argmax(-1) | |
max_axis, uc, vc = ( | |
torch.ones_like(abs_x), | |
torch.zeros_like(tri_stack[..., :1]), | |
torch.zeros_like(tri_stack[..., :1]), | |
) | |
mask_pos_x = index == 0 | |
max_axis[mask_pos_x] = abs_x[mask_pos_x] | |
uc[mask_pos_x] = tri_stack[mask_pos_x][..., 1:2] | |
vc[mask_pos_x] = -tri_stack[mask_pos_x][..., -1:] | |
mask_neg_x = index == 1 | |
max_axis[mask_neg_x] = abs_x[mask_neg_x] | |
uc[mask_neg_x] = tri_stack[mask_neg_x][..., 1:2] | |
vc[mask_neg_x] = -tri_stack[mask_neg_x][..., -1:] | |
mask_pos_y = index == 2 | |
max_axis[mask_pos_y] = abs_y[mask_pos_y] | |
uc[mask_pos_y] = tri_stack[mask_pos_y][..., 0:1] | |
vc[mask_pos_y] = -tri_stack[mask_pos_y][..., -1:] | |
mask_neg_y = index == 3 | |
max_axis[mask_neg_y] = abs_y[mask_neg_y] | |
uc[mask_neg_y] = tri_stack[mask_neg_y][..., 0:1] | |
vc[mask_neg_y] = -tri_stack[mask_neg_y][..., -1:] | |
mask_pos_z = index == 4 | |
max_axis[mask_pos_z] = abs_z[mask_pos_z] | |
uc[mask_pos_z] = tri_stack[mask_pos_z][..., 0:1] | |
vc[mask_pos_z] = tri_stack[mask_pos_z][..., 1:2] | |
mask_neg_z = index == 5 | |
max_axis[mask_neg_z] = abs_z[mask_neg_z] | |
uc[mask_neg_z] = tri_stack[mask_neg_z][..., 0:1] | |
vc[mask_neg_z] = -tri_stack[mask_neg_z][..., 1:2] | |
# UC from [-1, 1] to [0, 1] | |
max_dim_div = max_axis.max(dim=0, keepdims=True).values | |
uc = ((uc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1) | |
vc = ((vc[..., 0] / max_dim_div + 1.0) * 0.5).clip(0, 1) | |
uv = torch.stack([uc, vc], dim=-1) | |
return uv, index | |
def _assign_faces_uv_to_atlas_index(
    vertex_positions: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    face_uv: Float[Tensor, "Nf 3 2"],
    face_index: Integer[Tensor, "Nf"],  # noqa: F821
) -> Integer[Tensor, "Nf"]:  # noqa: F821
    """Push faces whose projected UVs overlap into later atlas slices.

    Atlas indices come in slices of six (one entry per cube face): slice 0 is
    ``0..5``, slice 1 is ``6..11``, slice 2 is ``12..17``. In pass ``s`` the
    faces of every index in slice ``s`` are tested pairwise for UV overlap, and
    the faces flagged as overlapping are shifted by +6 into the next slice.
    The result is clamped to at most 12, so every face that fit into neither
    slice 0 nor slice 1 ends up with index 12.

    Args:
        vertex_positions: Vertex positions, used to order overlapping pairs
            along the projection axis of their cube face.
        triangle_idxs: Vertex indices per face, shape ``(Nf, 3)``.
        face_uv: Per-corner UVs in the local space of each face's cube face.
        face_index: Initial cube face per face (0..5), shape ``(Nf,)``.

    Returns:
        Atlas index per face in ``0..12``.
    """
    triangle_pos = vertex_positions[triangle_idxs]
    assign_idx = face_index.clone()
    for overlap_step in range(3):
        overlapping_indicator = torch.zeros_like(assign_idx, dtype=torch.bool)
        for i in range(overlap_step * 6, (overlap_step + 1) * 6):
            mask = assign_idx == i
            if not mask.any():
                continue
            # All faces currently assigned to atlas index i.
            uv_triangle = face_uv[mask]
            cur_triangle_pos = triangle_pos[mask]

            # Cheap pre-filter: pair triangles whose UV centres are within 3x the
            # largest centre-to-corner distance of the row triangle. The diagonal
            # gets +1000 so no triangle is paired with itself.
            # NOTE: this builds a (k, k, 1) distance matrix, so memory grows
            # quadratically with the number of faces in the group.
            center_uv = uv_triangle.mean(dim=1, keepdim=True)
            uv_triangle_radius = (uv_triangle - center_uv).norm(dim=-1).max(-1).values
            center_dist = (center_uv[None, ...] - center_uv[:, None]).norm(
                dim=-1
            ) + torch.eye(
                uv_triangle.shape[0],
                device=uv_triangle.device,
                dtype=uv_triangle.dtype,
            ).unsqueeze(-1) * 1000
            potentially_overlapping_mask = (
                center_dist <= (uv_triangle_radius.view(-1, 1, 1) * 3.0)
            ).squeeze(-1)
            overlap_coords = torch.stack(torch.where(potentially_overlapping_mask), -1)
            if overlap_coords.shape[0] == 0:
                # No candidate pairs: nothing in this group can overlap.
                continue

            # (a, b) and (b, a) are the same pair; keep a single ordering.
            f = torch.min(overlap_coords, dim=-1).values
            s = torch.max(overlap_coords, dim=-1).values
            overlap_coords = torch.unique(torch.stack([f, s], dim=1), dim=0)
            first, second = overlap_coords.unbind(-1)

            # Exact intersection test on the reduced candidate set.
            its = triangle_intersection_2d(
                uv_triangle[first], uv_triangle[second], eps=1e-6
            )

            # Order each pair along the projection axis of its cube face. Use the
            # cube face (i % 6), not the raw atlas index, so overlap slices order
            # along the same axis as slice 0 (the raw index sent every i >= 4,
            # including all of slice 1, to the z axis).
            cube_face = i % 6
            ax = cube_face // 2
            use_max = cube_face % 2 == 1
            tri1_c = cur_triangle_pos[first].mean(dim=1)
            tri2_c = cur_triangle_pos[second].mean(dim=1)
            mark_first = (
                (tri1_c[..., ax] > tri2_c[..., ax])
                if use_max
                else (tri1_c[..., ax] < tri2_c[..., ax])
            )
            # Afterwards `first` holds, per pair, the triangle with the larger
            # centroid coordinate along `ax` for even cube faces (smaller for odd
            # ones); that triangle is the one moved if the pair intersects.
            # NOTE(review): confirm that moving this triangle (rather than its
            # partner) is the intended occlusion rule.
            first[mark_first] = second[mark_first]

            # A face can appear in several pairs; it is moved if any of them intersects.
            unique_idx, rev_idx = torch.unique(first, return_inverse=True)
            add = torch.zeros_like(unique_idx, dtype=torch.float32)
            add.index_add_(0, rev_idx, its.float())
            its_mask = add > 0
            # Map group-local indices back to global face indices.
            idx = torch.where(mask)[0][unique_idx]
            overlapping_indicator[idx] = its_mask
        # Move overlapping faces into the next slice.
        assign_idx[overlapping_indicator] += 6
    # Faces beyond the second slice all share the "remaining" index 12.
    max_idx = 6 * 2
    return assign_idx.clamp(0, max_idx)
def _find_slice_offset_and_scale( | |
index: Integer[Tensor, "Nf"], # noqa: F821 | |
) -> Tuple[ | |
Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"], Float[Tensor, "Nf"] # noqa: F821 | |
]: # noqa: F821 | |
# 6 due to the 6 cube faces | |
off = 1 / 3 | |
dupl_off = 1 / 6 | |
# Here, we need to decide how to pack the textures in the case of overlap | |
def x_offset_calc(x, i): | |
offset_calc = i // 6 | |
# Initial coordinates - just 3x2 grid | |
if offset_calc == 0: | |
return off * x | |
else: | |
# Smaller 3x2 grid plus eventual shift to right for | |
# second overlap | |
return dupl_off * x + min(offset_calc - 1, 1) * 0.5 | |
def y_offset_calc(x, i): | |
offset_calc = i // 6 | |
# Initial coordinates - just a 3x2 grid | |
if offset_calc == 0: | |
return off * x | |
else: | |
# Smaller coordinates in the lowest row | |
return dupl_off * x + off * 2 | |
offset_x = torch.zeros_like(index, dtype=torch.float32) | |
offset_y = torch.zeros_like(index, dtype=torch.float32) | |
offset_x_vals = [0, 1, 2, 0, 1, 2] | |
offset_y_vals = [0, 0, 0, 1, 1, 1] | |
for i in range(index.max().item() + 1): | |
mask = index == i | |
if not mask.any(): | |
continue | |
offset_x[mask] = x_offset_calc(offset_x_vals[i % 6], i) | |
offset_y[mask] = y_offset_calc(offset_y_vals[i % 6], i) | |
div_x = torch.full_like(index, 6 // 2, dtype=torch.float32) | |
# All overlap elements are saved in half scale | |
div_x[index >= 6] = 6 | |
div_y = div_x.clone() # Same for y | |
# Except for the random overlaps | |
div_x[index >= 12] = 2 | |
# But the random overlaps are saved in a large block in the lower thirds | |
div_y[index >= 12] = 3 | |
return offset_x, offset_y, div_x, div_y | |
def rotation_flip_matrix_2d(
    rad: float, flip_x: bool = False, flip_y: bool = False
) -> Float[Tensor, "2 2"]:
    """Return the float32 matrix ``Flip @ Rot`` for a 2D rotation plus optional axis flips.

    ``Rot`` is the counter-clockwise rotation by ``rad`` radians; ``Flip`` is a
    diagonal matrix with -1 for each flipped axis, so a flip negates the
    corresponding row of the rotation.
    """
    c, s = math.cos(rad), math.sin(rad)
    rotation = torch.tensor([[c, -s], [s, c]], dtype=torch.float32)
    signs = [-1 if flip_x else 1, -1 if flip_y else 1]
    flip = torch.diag(torch.tensor(signs, dtype=torch.float32))
    return flip @ rotation
def calculate_tangents(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    face_uv: Float[Tensor, "Nf 3 2"],
) -> Float[Tensor, "Nv 3"]:  # noqa: F821
    """Compute unit per-vertex tangents (direction of increasing u) from per-face UVs.

    Per-face tangents are accumulated onto the face's vertices, averaged, and
    Gram-Schmidt orthogonalised against the vertex normal.

    Args:
        vertex_positions: Vertex positions, shape ``(Nv, 3)``.
        vertex_normals: Vertex normals, shape ``(Nv, 3)``.
        triangle_idxs: Vertex indices per face, shape ``(Nf, 3)`` (normals use
            the same indices as positions).
        face_uv: Per-corner UVs, shape ``(Nf, 3, 2)``.

    Returns:
        Tangents of shape ``(Nv, 3)``; vertices used by no face get a zero vector.
    """
    pos = [vertex_positions[triangle_idxs[:, k]] for k in range(3)]
    tex = face_uv.unbind(1)

    # Tangent of each triangle from its position and UV edge differences.
    duv1 = tex[1] - tex[0]
    duv2 = tex[2] - tex[0]
    dpos1 = pos[1] - pos[0]
    dpos2 = pos[2] - pos[0]
    tng_nom = dpos1 * duv2[..., 1:2] - dpos2 * duv1[..., 1:2]
    denom = duv1[..., 0:1] * duv2[..., 1:2] - duv1[..., 1:2] * duv2[..., 0:1]
    # Keep |denom| >= 1e-6 to avoid division by zero for degenerate UVs, but
    # preserve its sign: mirrored UV triangles have a negative denominator, and
    # clamping it to a positive value would flip their tangent and blow it up.
    denom_safe = torch.where(
        denom > 0.0, denom.clamp(min=1e-6), denom.clamp(max=-1e-6)
    )
    tang = tng_nom / denom_safe

    # Accumulate onto the three corners of every face.
    tangents = torch.zeros_like(vertex_normals)
    tansum = torch.zeros_like(vertex_normals)
    for k in range(3):
        idx = triangle_idxs[:, k][:, None].repeat(1, 3)
        tangents.scatter_add_(0, idx, tang)
        tansum.scatter_add_(0, idx, torch.ones_like(tang))

    # Average. Face tangents are not normalised first, so their magnitudes act
    # as weights. The clamp only affects unreferenced vertices (count 0), which
    # would otherwise become 0/0 = NaN.
    tangents = tangents / tansum.clamp(min=1.0)
    tangents = F.normalize(tangents, dim=1)
    # Remove the component along the normal so the tangent lies in the tangent plane.
    normal_component = torch.sum(tangents * vertex_normals, -1, keepdim=True)
    tangents = F.normalize(tangents - normal_component * vertex_normals)
    return tangents
def _rotate_uv_slices_consistent_space(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    uv: Float[Tensor, "Nf 3 2"],
    index: Integer[Tensor, "Nf"],  # noqa: F821
):
    """Rotate each cube-face UV group towards a common reference tangent direction.

    For every cube face ``i`` (faces with ``index % 6 == i``), the mean UV
    tangent of the group is compared with the mean of a reference tangent field.
    The group's UVs are rotated by the angle between the two about the UV centre
    (0.5, 0.5) and then min-max rescaled (jointly over u and v) to [0, 1].

    Note: ``uv`` is modified in place and also returned.
    """
    tangents = calculate_tangents(vertex_positions, vertex_normals, triangle_idxs, uv)

    # Reference direction: (-y, x, 0) circles the z axis; n x (p x n) removes its
    # component along n (for unit normals), keeping it in the tangent plane.
    pos_stack = torch.stack(
        [
            -vertex_positions[..., 1],
            vertex_positions[..., 0],
            torch.zeros_like(vertex_positions[..., 0]),
        ],
        dim=-1,
    )
    # `dim` must be a keyword: F.normalize's second positional parameter is the
    # norm order p, so F.normalize(x, -1) computed a p=-1 "norm" and left the
    # reference tangents non-unit (zero components blew them up by 1e12).
    expected_tangents = F.normalize(
        torch.linalg.cross(
            vertex_normals, torch.linalg.cross(pos_stack, vertex_normals)
        ),
        dim=-1,
    )

    actual_tangents = tangents[triangle_idxs]
    expected_tangents = expected_tangents[triangle_idxs]

    def rotation_matrix_2d(theta):
        """Counter-clockwise 2x2 rotation matrix with theta's dtype and device."""
        c, s = torch.cos(theta), torch.sin(theta)
        return torch.stack([torch.stack([c, -s]), torch.stack([s, c])])

    # index is expected to be 0..5 already; the modulo is only a safeguard.
    index_mod = index % 6
    for i in range(6):
        mask = index_mod == i
        if not mask.any():
            continue
        actual_mean_tangent = actual_tangents[mask].mean(dim=(0, 1))
        expected_mean_tangent = expected_tangents[mask].mean(dim=(0, 1))
        # NOTE(review): the dot product uses all three components while the
        # cross term only uses x/y (the z component of the 3D cross product);
        # confirm this mix is intended.
        dot_product = torch.dot(actual_mean_tangent, expected_mean_tangent)
        cross_product = (
            actual_mean_tangent[0] * expected_mean_tangent[1]
            - actual_mean_tangent[1] * expected_mean_tangent[0]
        )
        angle = torch.atan2(cross_product, dot_product)
        rot_matrix = rotation_matrix_2d(angle).to(device=uv.device, dtype=uv.dtype)

        # Centre the group at the origin ([-1, 1]), rotate it, then rescale to [0, 1].
        uv_cur = uv[mask] * 2 - 1
        rotated = torch.einsum("ij,nfj->nfi", rot_matrix, uv_cur)
        uv[mask] = (rotated - rotated.min()) / (rotated.max() - rotated.min())
    return uv
def _handle_slice_uvs( | |
uv: Float[Tensor, "Nf 3 2"], | |
index: Integer[Tensor, "Nf"], # noqa: F821 | |
island_padding: float, | |
max_index: int = 6 * 2, | |
) -> Float[Tensor, "Nf 3 2"]: # noqa: F821 | |
uc, vc = uv.unbind(-1) | |
# Get the second slice (The first overlap) | |
index_filter = [index == i for i in range(6, max_index)] | |
# Normalize them to always fully fill the atlas patch | |
for i, fi in enumerate(index_filter): | |
if fi.sum() > 0: | |
# Scale the slice but only up to a factor of 2 | |
# This keeps the texture resolution with the first slice in line (Half space in UV) | |
uc[fi] = (uc[fi] - uc[fi].min()) / (uc[fi].max() - uc[fi].min()).clip(0.5) | |
vc[fi] = (vc[fi] - vc[fi].min()) / (vc[fi].max() - vc[fi].min()).clip(0.5) | |
uc_padded = (uc * (1 - 2 * island_padding) + island_padding).clip(0, 1) | |
vc_padded = (vc * (1 - 2 * island_padding) + island_padding).clip(0, 1) | |
return torch.stack([uc_padded, vc_padded], dim=-1) | |
def _handle_remaining_uvs( | |
uv: Float[Tensor, "Nf 3 2"], | |
index: Integer[Tensor, "Nf"], # noqa: F821 | |
island_padding: float, | |
) -> Float[Tensor, "Nf 3 2"]: | |
uc, vc = uv.unbind(-1) | |
# Get all remaining elements | |
remaining_filter = index >= 6 * 2 | |
squares_left = remaining_filter.sum() | |
if squares_left == 0: | |
return uv | |
uc = uc[remaining_filter] | |
vc = vc[remaining_filter] | |
# Or remaining triangles are distributed in a rectangle | |
# The rectangle takes 0.5 of the entire uv space in width and 1/3 in height | |
ratio = 0.5 * (1 / 3) # 1.5 | |
# sqrt(744/(0.5*(1/3))) | |
mult = math.sqrt(squares_left / ratio) | |
num_square_width = int(math.ceil(0.5 * mult)) | |
num_square_height = int(math.ceil(squares_left / num_square_width)) | |
width = 1 / num_square_width | |
height = 1 / num_square_height | |
# The idea is again to keep the texture resolution consistent with the first slice | |
# This only occupys half the region in the texture chart but the scaling on the squares | |
# assumes full coverage. | |
clip_val = min(width, height) * 1.5 | |
# Now normalize the UVs with taking into account the maximum scaling | |
uc = (uc - uc.min(dim=1, keepdim=True).values) / ( | |
uc.amax(dim=1, keepdim=True) - uc.amin(dim=1, keepdim=True) | |
).clip(clip_val) | |
vc = (vc - vc.min(dim=1, keepdim=True).values) / ( | |
vc.amax(dim=1, keepdim=True) - vc.amin(dim=1, keepdim=True) | |
).clip(clip_val) | |
# Add a small padding | |
uc = ( | |
uc * (1 - island_padding * num_square_width * 0.5) | |
+ island_padding * num_square_width * 0.25 | |
).clip(0, 1) | |
vc = ( | |
vc * (1 - island_padding * num_square_height * 0.5) | |
+ island_padding * num_square_height * 0.25 | |
).clip(0, 1) | |
uc = uc * width | |
vc = vc * height | |
# And calculate offsets for each element | |
idx = torch.arange(uc.shape[0], device=uc.device, dtype=torch.int32) | |
x_idx = idx % num_square_width | |
y_idx = idx // num_square_width | |
# And move each triangle to its own spot | |
uc = uc + x_idx[:, None] * width | |
vc = vc + y_idx[:, None] * height | |
uc = (uc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1) | |
vc = (vc * (1 - 2 * island_padding * 0.5) + island_padding * 0.5).clip(0, 1) | |
uv[remaining_filter] = torch.stack([uc, vc], dim=-1) | |
return uv | |
def _distribute_individual_uvs_in_atlas(
    face_uv: Float[Tensor, "Nf 3 2"],
    assigned_faces: Integer[Tensor, "Nf"],  # noqa: F821
    offset_x: Float[Tensor, "Nf"],  # noqa: F821
    offset_y: Float[Tensor, "Nf"],  # noqa: F821
    div_x: Float[Tensor, "Nf"],  # noqa: F821
    div_y: Float[Tensor, "Nf"],  # noqa: F821
    island_padding: float,
):
    """Move every face's local UVs to its final place in the atlas.

    Slice groups are normalised and padded first, and the remaining faces are
    gridded. Each face is then mapped with ``uv / div + offset`` per axis.

    Returns:
        Atlas UVs flattened per corner, shape ``(Nf * 3, 2)``.
    """
    local_uv = _handle_remaining_uvs(
        _handle_slice_uvs(face_uv, assigned_faces, island_padding),
        assigned_faces,
        island_padding,
    )
    u_local, v_local = local_uv.unbind(-1)
    u_atlas = u_local / div_x.unsqueeze(1) + offset_x.unsqueeze(1)
    v_atlas = v_local / div_y.unsqueeze(1) + offset_y.unsqueeze(1)
    return torch.stack((u_atlas, v_atlas), dim=-1).view(-1, 2)
def _get_unique_face_uv( | |
uv: Float[Tensor, "Nf 3 2"], | |
) -> Tuple[Float[Tensor, "Utex 3"], Integer[Tensor, "Nf"]]: # noqa: F821 | |
unique_uv, unique_idx = torch.unique(uv, return_inverse=True, dim=0) | |
# And add the face to uv index mapping | |
vtex_idx = unique_idx.view(-1, 3) | |
return unique_uv, vtex_idx | |
def _align_mesh_with_main_axis( | |
vertex_positions: Float[Tensor, "Nv 3"], vertex_normals: Float[Tensor, "Nv 3"] | |
) -> Tuple[Float[Tensor, "Nv 3"], Float[Tensor, "Nv 3"]]: | |
# Use pca to find the 2 main axis (third is derived by cross product) | |
# Set the random seed so it's repeatable | |
torch.manual_seed(0) | |
_, _, v = torch.pca_lowrank(vertex_positions, q=2) | |
main_axis, seconday_axis = v[:, 0], v[:, 1] | |
main_axis: Float[Tensor, "3"] = F.normalize(main_axis, eps=1e-6, dim=-1) | |
# Orthogonalize the second axis | |
seconday_axis: Float[Tensor, "3"] = F.normalize( | |
seconday_axis - dot(seconday_axis, main_axis) * main_axis, eps=1e-6, dim=-1 | |
) | |
# Create perpendicular third axis | |
third_axis: Float[Tensor, "3"] = F.normalize( | |
torch.cross(main_axis, seconday_axis), dim=-1, eps=1e-6 | |
) | |
# Check to which canonical axis each aligns | |
main_axis_max_idx = main_axis.abs().argmax().item() | |
seconday_axis_max_idx = seconday_axis.abs().argmax().item() | |
third_axis_max_idx = third_axis.abs().argmax().item() | |
# Now sort the axes based on the argmax so they align with thecanonoical axes | |
# If two axes have the same argmax move one of them | |
all_possible_axis = {0, 1, 2} | |
cur_index = 1 | |
while len(set([main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx])) != 3: | |
# Find missing axis | |
missing_axis = all_possible_axis - set( | |
[main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx] | |
) | |
missing_axis = missing_axis.pop() | |
# Just assign it to third axis as it had the smallest contribution to the | |
# overall shape | |
if cur_index == 1: | |
third_axis_max_idx = missing_axis | |
elif cur_index == 2: | |
seconday_axis_max_idx = missing_axis | |
else: | |
raise ValueError("Could not find 3 unique axis") | |
cur_index += 1 | |
if len({main_axis_max_idx, seconday_axis_max_idx, third_axis_max_idx}) != 3: | |
raise ValueError("Could not find 3 unique axis") | |
axes = [None] * 3 | |
axes[main_axis_max_idx] = main_axis | |
axes[seconday_axis_max_idx] = seconday_axis | |
axes[third_axis_max_idx] = third_axis | |
# Create rotation matrix from the individual axes | |
rot_mat = torch.stack(axes, dim=1).T | |
# Now rotate the vertex positions and vertex normals so the mesh aligns with the main axis | |
vertex_positions = torch.einsum("ij,nj->ni", rot_mat, vertex_positions) | |
vertex_normals = torch.einsum("ij,nj->ni", rot_mat, vertex_normals) | |
return vertex_positions, vertex_normals | |
def box_projection_uv_unwrap(
    vertex_positions: Float[Tensor, "Nv 3"],
    vertex_normals: Float[Tensor, "Nv 3"],
    triangle_idxs: Integer[Tensor, "Nf 3"],
    island_padding: float,
) -> Tuple[Float[Tensor, "Utex 2"], Integer[Tensor, "Nf 3"]]:  # noqa: F821
    """UV-unwrap a triangle mesh by projecting faces onto the sides of its bounding box.

    Pipeline:
    1. Assign every face to one of six cube faces and project it.
    2. Rotate each cube-face group into a common tangent orientation.
    3. Push overlapping faces into secondary atlas slices.
    4. Pack all faces into a single atlas.

    Args:
        vertex_positions: Vertex positions, shape ``(Nv, 3)``.
        vertex_normals: Per-vertex normals, shape ``(Nv, 3)``; they decide the
            cube face of each triangle.
        triangle_idxs: Vertex indices per face, shape ``(Nf, 3)``.
        island_padding: Relative inset applied around UV islands during packing.

    Returns:
        ``(uv, uv_idx)``: the unique UV coordinates, shape ``(Utex, 2)``, and
        per face the indices of its three corners into ``uv``, shape ``(Nf, 3)``.
    """
    # PCA pre-alignment (_align_mesh_with_main_axis) is currently disabled.
    bbox: Float[Tensor, "2 3"] = torch.stack(
        [vertex_positions.min(dim=0).values, vertex_positions.max(dim=0).values], dim=0
    )

    # Decide on which cube face each triangle is projected.
    face_uv, face_index = _box_assign_vertex_to_cube_face(
        vertex_positions, vertex_normals, triangle_idxs, bbox
    )

    # Rotate the UV islands so they align with the radial-z tangent field.
    face_uv = _rotate_uv_slices_consistent_space(
        vertex_positions, vertex_normals, triangle_idxs, face_uv, face_index
    )

    # Find the atlas slice of every face; overlapping faces move to later slices.
    assigned_atlas_index = _assign_faces_uv_to_atlas_index(
        vertex_positions, triangle_idxs, face_uv, face_index
    )

    # Offset and scale of each face's patch inside the atlas.
    offset_x, offset_y, div_x, div_y = _find_slice_offset_and_scale(
        assigned_atlas_index
    )

    # Place the faces in the atlas (flattened per corner).
    placed_uv = _distribute_individual_uvs_in_atlas(
        face_uv, assigned_atlas_index, offset_x, offset_y, div_x, div_y, island_padding
    )

    # Deduplicate the corner UVs and build the face -> UV index table.
    return _get_unique_face_uv(placed_uv)