Spaces:
Sleeping
Sleeping
File size: 2,157 Bytes
e9e75df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 |
import pytest
import torch
from chroma.models.graph_design import GraphDesign, ProteinTraversalSpatial
@pytest.fixture
def model():
    """Provide a ``GraphDesign`` network switched to evaluation mode.

    The network is constructed with ``predict_S_marginals=True`` and
    ``predict_S_potts=True`` so that tests can inspect the corresponding
    outputs.
    """
    # Use a distinct local name so the fixture function is not shadowed.
    graph_design = GraphDesign(predict_S_marginals=True, predict_S_potts=True)
    graph_design.eval()
    return graph_design
def test_sequential_decoding(model, XCS):
    """Test that the sequential and parallelized decoding of GNN agree.

    Two comparisons are made, both under one shared ``permute_idx``:

    1. ``model.sample(..., clamped=True)`` on the input ``(X, C, S)`` versus
       a direct forward pass on that same input.
    2. ``model.sample(..., clamped=False)`` versus a forward pass on the
       ``X_sample`` / ``S_sample`` returned by that sampling call.

    In both cases the ``logp_S`` and ``logp_chi`` entries of the two score
    dictionaries must agree under ``torch.allclose`` with ``atol=1e-3``.

    Args:
        model: Network under test, injected by pytest.
        XCS: Fixture unpacked as ``(X, C, S)``; it is defined outside this
            file.
    """
    X, C, S = XCS

    # Fix a permutation: argsort of Gaussian noise along the last dimension of
    # C gives a random ordering, which is reused for every call below so the
    # forward pass and the sampler see the same ordering.
    permute_idx = torch.argsort(torch.randn_like(C.float()), dim=-1)

    def _assert_scores_agree(scores_parallel, scores_sequential):
        # Compare the two score dictionaries key by key.
        for key in ("logp_S", "logp_chi"):
            assert torch.allclose(
                scores_parallel[key], scores_sequential[key], atol=1e-3
            ), f"parallel and sequential '{key}' disagree"

    # --- Clamped sampling vs. forward pass on the original inputs ---
    # NOTE(review): presumably clamped=True makes sample() follow the given S
    # instead of drawing new values -- confirm against GraphDesign.sample.
    scores_parallel = model(X, C, S, permute_idx=permute_idx)
    _, _, _, scores_sequential = model.sample(
        X, C, S, permute_idx=permute_idx, clamped=True, return_scores=True
    )
    _assert_scores_agree(scores_parallel, scores_sequential)

    # --- Unclamped sampling vs. forward pass on the sampled outputs ---
    # The forward pass is evaluated on what the sampler returned, so both
    # score dictionaries refer to the same X_sample / S_sample.
    X_sample, S_sample, _, scores_sequential = model.sample(
        X, C, S, permute_idx=permute_idx, clamped=False, return_scores=True
    )
    scores_parallel = model(X_sample, C, S_sample, permute_idx=permute_idx)
    _assert_scores_agree(scores_parallel, scores_sequential)
def test_deterministic_traversal(XCS):
    """Check deterministic flag on ProteinTraversalSpatial module.

    A single ``ProteinTraversalSpatial(deterministic=True)`` instance is
    called twice on the same ``(X, C)`` pair, and the two results must
    satisfy ``torch.allclose``.

    Args:
        XCS: Fixture unpacked as ``(X, C, _)``; the third element is unused.
    """
    spatial_traversal = ProteinTraversalSpatial(deterministic=True)
    X, C, _ = XCS

    # Two independent calls on identical inputs, evaluated left to right.
    first_order, second_order = spatial_traversal(X, C), spatial_traversal(X, C)

    assert torch.allclose(first_order, second_order)
def test_graph_design_outputs(model, XCS):
    """Smoke test all GraphDesign outputs.

    Runs one forward pass and checks, in order:

    * ``logp_S``, ``logp_S_marginals`` and ``logp_S_potts`` have shape
      ``X.shape[:2]``;
    * ``X_noise`` is ``torch.allclose`` to the input ``X``;
    * ``chi`` and ``logp_chi`` have leading shape ``X.shape[:2]`` and a
      trailing dimension of size 4.

    Args:
        model: Network under test, injected by pytest.
        XCS: Fixture unpacked as ``(X, C, S)``.
    """
    X, C, S = XCS
    result = model(X, C, S)
    # NOTE(review): presumably the first two dims of X are (batch, residues)
    # -- confirm against the XCS fixture.
    leading_dims = X.shape[:2]

    # Per-position sequence log-probabilities.
    for name in ("logp_S", "logp_S_marginals", "logp_S_potts"):
        assert result[name].shape == leading_dims

    # The forward pass is expected to hand back the coordinates it was given.
    assert torch.allclose(result["X_noise"], X)

    # Outputs that carry one extra trailing dimension of size 4.
    for name in ("chi", "logp_chi"):
        value_shape = result[name].shape
        assert value_shape[:-1] == leading_dims
        assert value_shape[-1] == 4
|