Spaces:
Sleeping
Sleeping
import pytest | |
import torch | |
from chroma.layers.structure.optimal_transport import ( | |
optimize_couplings_gw, | |
optimize_couplings_sinkhorn, | |
) | |
# test sinkhorn | |
def test_sinkhorn():
    """Check the Sinkhorn coupling on a single 3x3 permutation-matrix input.

    The input has shape (1, 3, 3). The test expects the row-wise argmin of
    the returned coupling to be [0, 2, 1], i.e. the column holding the 1 in
    each row of the input.
    """
    permutation = torch.Tensor([[[1, 0, 0], [0, 0, 1], [0, 1, 0]]])
    observed = optimize_couplings_sinkhorn(permutation).argmin(-1)
    expected = torch.LongTensor([[0, 2, 1]])
    assert torch.allclose(observed, expected)
def test_gw():
    """Smoke-test optimize_couplings_gw on two identical batches.

    Builds a random 4x4 non-negative rank-one matrix, stacks it twice to form
    each of the two inputs, and only checks that the call completes; the
    result is not inspected.
    """
    # TODO: need a nontrivial test
    weights = torch.abs(torch.randn(4))
    pairwise = torch.outer(weights, weights)
    # Build the two inputs as separate tensors, as the original test did.
    batch_a = torch.stack((pairwise, pairwise))
    batch_b = torch.stack((pairwise, pairwise))
    optimize_couplings_gw(batch_a, batch_b, scale=2)