Hukuna's picture
Upload 221 files
ce7bf5b verified
raw
history blame
4.24 kB
import pytest
import torch
from chroma.layers.structure.geometry import rotations_from_quaternions
from chroma.layers.structure.transforms import (
average_transforms,
collect_neighbor_transforms,
compose_inner_transforms,
compose_transforms,
compose_translation,
equilibrate_transforms,
fuse_gaussians_isometric_plus_radial,
)
@pytest.fixture
def vec():
    """Provide a reproducible random 3-element vector used as a test point.

    The global torch RNG is re-seeded with 0 before drawing, so every test
    that requests this fixture sees the same values.
    """
    # NOTE(review): the other fixtures in this file also seed with 0, so their
    # draws are likely correlated with this one -- confirm that is intended.
    torch.manual_seed(0)
    point = torch.rand(3)
    return point
@pytest.fixture
def rotations():
    """Provide two rotations derived from seeded random quaternions.

    A ``(2, 4)`` batch of uniform random values is handed to
    ``rotations_from_quaternions`` with ``normalize=True`` and the result is
    split along its first dimension, giving one entry per quaternion.
    """
    torch.manual_seed(0)
    quaternions = torch.rand(2, 4)
    # presumably a (2, 3, 3) batch of rotation matrices -- verify against
    # rotations_from_quaternions
    rotation_batch = rotations_from_quaternions(quaternions, normalize=True)
    return rotation_batch.unbind()
@pytest.fixture
def translations():
    """Provide two reproducible random 3-element translation vectors.

    Returns a tuple obtained by splitting a seeded ``(2, 3)`` random tensor
    along its first dimension.
    """
    torch.manual_seed(0)
    offsets = torch.rand(2, 3)
    return offsets.unbind()
def test_compose_transforms(vec, rotations, translations):
    """Composed transform must match applying (R_b, t_b) and then (R_a, t_a)."""
    (R_a, R_b), (t_a, t_b) = rotations, translations

    # Reference result: transform by b first, then by a.
    after_b = R_b @ vec + t_b
    expected = R_a @ after_b + t_a

    R_ab, t_ab = compose_transforms(R_a, t_a, R_b, t_b)
    actual = R_ab @ vec + t_ab

    assert torch.allclose(expected, actual)
def test_compose_translation(vec, rotations, translations):
    """Composed translation must match shifting by t_b and then applying (R_a, t_a)."""
    R_a, _unused_rotation = rotations
    t_a, t_b = translations

    # Reference result: pure translation by t_b, followed by the a transform.
    shifted = vec + t_b
    expected = R_a @ shifted + t_a

    t_ab = compose_translation(R_a, t_a, t_b)
    actual = R_a @ vec + t_ab

    assert torch.allclose(expected, actual)
def test_compose_inner_transforms(vec, rotations, translations):
    """Inner composition must match applying (R_b, t_b), then undoing (R_a, t_a)."""
    (R_a, R_b), (t_a, t_b) = rotations, translations

    # Reference result: x -> R_a^{-1} @ ((R_b @ x + t_b) - t_a)
    after_b = R_b @ vec + t_b
    expected = torch.inverse(R_a) @ (after_b - t_a)

    R_inner, t_inner = compose_inner_transforms(R_a, t_a, R_b, t_b)
    actual = R_inner @ vec + t_inner

    # Tolerances are loosened because the reference uses an explicit matrix
    # inversion.
    assert torch.allclose(expected, actual, atol=1e-3, rtol=1e-2)
def test_fuse_gaussians_isometric_plus_radial(vec):
    """Check the fused output against a weighted sum of the two input points.

    The inputs are ``vec`` and ``2 * vec`` with isometric weights 0.3 and 0.7,
    and zero radial weights and directions. ``P_fused @ x_fused`` is expected
    to equal ``(0.3 + 2 * 0.7) * vec``.
    """
    iso_weights = torch.tensor([0.3, 0.7])
    rad_weights = torch.zeros_like(iso_weights)
    points = torch.stack([vec, 2 * vec])
    no_direction = torch.zeros_like(points)

    # The trailing 0 is presumably the dimension to fuse over -- TODO confirm
    # against the fuse_gaussians_isometric_plus_radial signature.
    x_fused, P_fused = fuse_gaussians_isometric_plus_radial(
        points, iso_weights, rad_weights, no_direction, 0
    )

    expected = (iso_weights[0] + 2 * iso_weights[1]) * vec
    assert torch.allclose(expected, P_fused @ x_fused)
def test_collect_neighbor_transforms(rotations, translations):
    """Each of two nodes lists the other as its only neighbor.

    The collected neighbor transforms must therefore be the node transforms in
    reversed order, with an extra axis inserted at position 2.
    """
    # Prepend a leading axis of size 1 to the stacked fixtures.
    R_i = torch.stack(rotations)[None]
    t_i = torch.stack(translations)[None]
    # Shape (1, 2, 1): node 0 points at node 1 and node 1 points at node 0.
    edge_idx = torch.tensor([[[1], [0]]], dtype=torch.long)

    R_j, t_j = collect_neighbor_transforms(R_i, t_i, edge_idx)

    assert torch.allclose(R_j, R_i.flip(1).unsqueeze(2))
    assert torch.allclose(t_j, t_i.flip(1).unsqueeze(2))
def test_equilibrate_transforms(rotations, translations):
    """Equilibrating two mutually-neighboring nodes swaps, then restores, them.

    With identity relative rotations and zero relative translations on the
    edges, one iteration should exchange the two node transforms and two
    iterations should bring them back to their starting values.
    """
    # Prepend a leading axis of size 1 to the stacked fixtures.
    R_i = torch.stack(rotations)[None]
    t_i = torch.stack(translations)[None]

    # Edge inputs: identity rotation, zero translation, unit logits and mask.
    R_ji = torch.eye(3).expand(1, 2, 1, 3, 3)
    t_ji = torch.zeros(1, 2, 1, 3)
    logit_ij = torch.ones(1, 2, 1, 1)
    mask_ij = torch.ones(1, 2, 1)
    # Shape (1, 2, 1): node 0 points at node 1 and node 1 points at node 0.
    edge_idx = torch.tensor([[[1], [0]]], dtype=torch.long)

    tolerance = {"atol": 1e-3, "rtol": 1e-2}

    # (iterations, whether the node transforms are expected to be swapped)
    for num_iterations, expect_swapped in ((1, True), (2, False)):
        R_eq, t_eq = equilibrate_transforms(
            R_i,
            t_i,
            R_ji,
            t_ji,
            logit_ij,
            mask_ij,
            edge_idx,
            iterations=num_iterations,
        )
        # Expected values are derived after the call, matching the order in
        # which the inputs are read.
        R_expected = R_i.flip(1) if expect_swapped else R_i
        t_expected = t_i.flip(1) if expect_swapped else t_i
        assert torch.allclose(R_eq, R_expected, **tolerance)
        assert torch.allclose(t_eq, t_expected, **tolerance)
def test_average_transforms(rotations, translations):
    """Averaging a transform with the identity should give "half" of it.

    This is checked by composing the averaged rotation with itself and the
    averaged translation with itself, then comparing against the original
    non-identity transform.
    """
    no_rotation = torch.eye(3)
    no_translation = torch.zeros(3)

    R = torch.stack([rotations[0], no_rotation])
    t = torch.stack([translations[0], no_translation])
    w = torch.ones(2, 2)
    mask = torch.ones(2)

    R_avg, t_avg = average_transforms(R, t, w, mask, dim=0, dither=False)

    # Apply the averaged rotation twice, with translations zeroed out.
    R_doubled, _ = compose_transforms(R_avg, no_translation, R_avg, no_translation)
    # Apply the averaged translation twice, with rotations set to identity.
    _, t_doubled = compose_transforms(no_rotation, t_avg, no_rotation, t_avg)

    tolerance = {"atol": 1e-3, "rtol": 1e-2}
    assert torch.allclose(R_doubled, R[0], **tolerance)
    assert torch.allclose(t_doubled, t[0], **tolerance)