File size: 1,310 Bytes
26ce2a9 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
import torch
from liegroups.torch import utils
def test_isclose():
    """Check the element-wise closeness mask produced by ``utils.isclose``.

    The result is compared as a plain Python list, so the check does not
    depend on whether ``isclose`` returns a ``torch.bool`` mask or a legacy
    ``uint8`` (ByteTensor) mask. Unlike ``(a == b).all()``, which broadcasts,
    list equality also fails if the mask has the wrong number of elements.
    """
    tol = 1e-6
    mat = torch.Tensor([0, 1, tol, 10 * tol, 0.1 * tol])
    # An entry exactly ``tol`` away from the target is expected NOT to count
    # as close; only the entries well inside ``tol`` (0 and 0.1*tol) do.
    expected = [True, False, False, False, True]
    mask = utils.isclose(mat, 0., tol=tol)
    # bool() normalises 0/1 (uint8) and False/True (bool) to the same values.
    assert [bool(v) for v in mask.tolist()] == expected
def test_allclose():
    """Check ``utils.allclose`` on one accepted and one rejected tensor.

    The first tensor holds only entries well inside ``tol`` of zero. The
    second also holds entries at or beyond ``tol`` (1, tol, 10*tol), so it
    must be rejected.
    """
    tol = 1e-6
    within_tol = torch.Tensor([0.1 * tol, 0.01 * tol, 0, 0, 0])
    has_outliers = torch.Tensor([0, 1, tol, 10 * tol, 0.1 * tol])

    assert utils.allclose(within_tol, 0., tol=tol)
    assert not utils.allclose(has_outliers, 0., tol=tol)
def test_outer():
    """Check ``utils.outer`` against explicit matrix products.

    A single pair of vectors is compared with a column-times-row ``torch.mm``.
    A batch of vector pairs is compared with the equivalent ``torch.bmm``.
    """
    # Single vectors: (3, 1) @ (1, 3) -> (3, 3).
    left = torch.Tensor([1, 2, 3])
    right = torch.Tensor([0, 1, 2])
    got_single = utils.outer(left, right)
    want_single = torch.mm(left.unsqueeze(dim=1), right.unsqueeze(dim=0))
    assert (got_single == want_single).all()

    # Batch of two vector pairs: (2, 3, 1) bmm (2, 1, 3) -> (2, 3, 3).
    left_batch = torch.Tensor([[1, 2, 3], [4, 5, 6]])
    right_batch = torch.Tensor([[0, 1, 2], [3, 4, 5]])
    got_batch = utils.outer(left_batch, right_batch)
    want_batch = torch.bmm(left_batch.unsqueeze(dim=2),
                           right_batch.unsqueeze(dim=1))
    assert (got_batch == want_batch).all()
def test_trace():
    """Check ``utils.trace`` against ``torch.trace``.

    For a single 3x3 matrix, the first entry of the result must equal the
    matrix's trace. For a stack of two 3x3 matrices, the result must have
    two entries, each equal to the trace of the corresponding matrix.
    """
    single = torch.arange(1, 10).view(3, 3)
    assert utils.trace(single)[0] == torch.trace(single)

    # Stacking two (3, 3) matrices along a new leading dim gives (2, 3, 3).
    batch = torch.stack([torch.arange(1, 10).view(3, 3),
                         torch.arange(11, 20).view(3, 3)])
    traces = utils.trace(batch)
    assert len(traces) == 2
    for got, matrix in zip(traces, batch):
        assert got == torch.trace(matrix)
|