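"""Tests for the rigid-transform utilities in chroma.layers.structure.transforms."""
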
import pytest
import torch

from chroma.layers.structure.geometry import rotations_from_quaternions
from chroma.layers.structure.transforms import (
    average_transforms,
    collect_neighbor_transforms,
    compose_inner_transforms,
    compose_transforms,
    compose_translation,
    equilibrate_transforms,
    fuse_gaussians_isometric_plus_radial,
)


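# The fixtures below reseed the RNG so every test sees the same deterministic
# vector, pair of rotation matrices, and pair of translation vectors.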
@pytest.fixture
def vec():
    torch.manual_seed(0)
    return torch.rand(3)


@pytest.fixture
def rotations():
    torch.manual_seed(0)
    q = torch.rand(2, 4)
    return rotations_from_quaternions(q, normalize=True).unbind()


@pytest.fixture
def translations():
    torch.manual_seed(0)
    return torch.rand(2, 3).unbind()


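# compose_transforms should return (R, t) such that applying the composed
# transform to a vector equals applying transform b followed by transform a.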
def test_compose_transforms(vec, rotations, translations):
    R_a, R_b = rotations
    t_a, t_b = translations
    inter = R_b @ vec + t_b
    result = R_a @ inter + t_a
    R_composed, t_composed = compose_transforms(R_a, t_a, R_b, t_b)
    assert torch.allclose(result, R_composed @ vec + t_composed)


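# compose_translation composes only the translation part: translating by t_b
# and then applying (R_a, t_a) should match rotating by R_a and adding the
# composed translation.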
def test_compose_translation(vec, rotations, translations):
    R_a, _ = rotations
    t_a, t_b = translations
    inter = vec + t_b
    result = R_a @ inter + t_a
    t_composed = compose_translation(R_a, t_a, t_b)
    assert torch.allclose(result, R_a @ vec + t_composed)


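# compose_inner_transforms should express transform b in the frame of
# transform a, i.e. it should match applying b and then the inverse of a.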
def test_compose_inner_transforms(vec, rotations, translations):
    R_a, R_b = rotations
    t_a, t_b = translations
    R_a_inv = torch.inverse(R_a)
    inter = R_b @ vec + t_b
    result = R_a_inv @ (inter - t_a)
    R_composed, t_composed = compose_inner_transforms(R_a, t_a, R_b, t_b)
    # bump up tolerance because of matrix inversion
    assert torch.allclose(result, R_composed @ vec + t_composed, atol=1e-3, rtol=1e-2)


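# With zero radial precision, fusing the two Gaussians should reduce to an
# isotropic precision-weighted combination, so the fused precision applied to
# the fused mean should equal the precision-weighted sum of the input means.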
def test_fuse_gaussians_isometric_plus_radial(vec):
    p_iso = torch.tensor([0.3, 0.7])
    p_rad = torch.zeros_like(p_iso)
    x = torch.stack([vec, 2 * vec])
    direction = torch.zeros_like(x)
    x_fused, P_fused = fuse_gaussians_isometric_plus_radial(
        x, p_iso, p_rad, direction, 0
    )
    assert torch.allclose((p_iso[0] + 2 * p_iso[1]) * vec, P_fused @ x_fused)


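# With two nodes whose only neighbor is each other, gathering neighbor
# transforms should simply swap the two frames.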
def test_collect_neighbor_transforms(rotations, translations):
    R_i = torch.stack(rotations).unsqueeze(0)
    t_i = torch.stack(translations).unsqueeze(0)
    edge_idx = torch.LongTensor([[1], [0]]).unsqueeze(0)
    R_j, t_j = collect_neighbor_transforms(R_i, t_i, edge_idx)
    assert torch.allclose(R_j, torch.flip(R_i, [1]).unsqueeze(2))
    assert torch.allclose(t_j, torch.flip(t_i, [1]).unsqueeze(2))


def test_equilibrate_transforms(rotations, translations):
    R_i = torch.stack(rotations).unsqueeze(0)
    t_i = torch.stack(translations).unsqueeze(0)
    R_ji = torch.eye(3).expand(1, 2, 1, 3, 3)
    t_ji = torch.zeros(1, 2, 1, 3)
    logit_ij = torch.ones(1, 2, 1, 1)
    mask_ij = torch.ones(1, 2, 1)
    edge_idx = torch.LongTensor([[1], [0]]).unsqueeze(0)
    # two transforms on nodes that are each other's neighbors, so a single
    # iteration will just swap the transforms
    R_eq, t_eq = equilibrate_transforms(
        R_i, t_i, R_ji, t_ji, logit_ij, mask_ij, edge_idx, iterations=1
    )
    assert torch.allclose(R_eq, torch.flip(R_i, [1]), atol=1e-3, rtol=1e-2)
    assert torch.allclose(t_eq, torch.flip(t_i, [1]), atol=1e-3, rtol=1e-2)
    # two iterations move the transforms back to themselves
    R_eq, t_eq = equilibrate_transforms(
        R_i, t_i, R_ji, t_ji, logit_ij, mask_ij, edge_idx, iterations=2
    )
    assert torch.allclose(R_eq, R_i, atol=1e-3, rtol=1e-2)
    assert torch.allclose(t_eq, t_i, atol=1e-3, rtol=1e-2)


def test_average_transforms(rotations, translations):
    R = torch.stack([rotations[0], torch.eye(3)])
    t = torch.stack([translations[0], torch.zeros(3)])
    w = torch.ones(2, 2)
    mask = torch.ones(2)
    # The average of a transform with the identity is "half" the transform,
    # so composing the average with itself should recover the original.
    R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)
    R_total_fromavg, _ = compose_transforms(
        R_avg, torch.zeros(3), R_avg, torch.zeros(3)
    )
    _, t_total_fromavg = compose_transforms(torch.eye(3), t_avg, torch.eye(3), t_avg)
    assert torch.allclose(R_total_fromavg, R[0], atol=1e-3, rtol=1e-2)
    assert torch.allclose(t_total_fromavg, t[0], atol=1e-3, rtol=1e-2)