File size: 2,157 Bytes
e9e75df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
import pytest
import torch

from chroma.models.graph_design import GraphDesign, ProteinTraversalSpatial


@pytest.fixture
def model():
    model = GraphDesign(predict_S_marginals=True, predict_S_potts=True)
    model.eval()
    return model


def test_sequential_decoding(model, XCS):
    """Test that the sequential and parallelized decoding of GNN agree."""

    from chroma.data import xcs

    X, C, S = XCS
    permute_idx = torch.argsort(torch.randn_like(C.float()), dim=-1)

    # Fix a permutation
    scores_parallel = model(X, C, S, permute_idx=permute_idx)

    _, _, _, scores_sequential = model.sample(
        X, C, S, permute_idx=permute_idx, clamped=True, return_scores=True
    )

    assert torch.allclose(
        scores_parallel["logp_S"], scores_sequential["logp_S"], atol=1e-3
    )
    assert torch.allclose(
        scores_parallel["logp_chi"], scores_sequential["logp_chi"], atol=1e-3
    )

    # =============Fix a permutation ========
    X_sample, S_sample, _, scores_sequential = model.sample(
        X, C, S, permute_idx=permute_idx, clamped=False, return_scores=True
    )
    scores_parallel = model(X_sample, C, S_sample, permute_idx=permute_idx)

    assert torch.allclose(
        scores_parallel["logp_S"], scores_sequential["logp_S"], atol=1e-3
    )
    assert torch.allclose(
        scores_parallel["logp_chi"], scores_sequential["logp_chi"], atol=1e-3
    )
    return


def test_deterministic_traversal(XCS):
    """Check deterministic flag on ProteinTraversalSpatial module."""
    traversal = ProteinTraversalSpatial(deterministic=True)
    X, C, _ = XCS
    permute_idx = traversal(X, C)
    permute_idx_2 = traversal(X, C)
    assert torch.allclose(permute_idx, permute_idx_2)
    return


def test_graph_design_outputs(model, XCS):
    """Smoke test all GraphDesign outputs."""
    X, C, S = XCS
    outputs = model(X, C, S)
    for key in ["logp_S", "logp_S_marginals", "logp_S_potts"]:
        assert outputs[key].shape == X.shape[:2]
    assert torch.allclose(outputs["X_noise"], X)
    for key in ["chi", "logp_chi"]:
        assert outputs[key].shape[:-1] == X.shape[:2] and outputs[key].shape[-1] == 4