Spaces:
Sleeping
Sleeping
import pytest | |
import torch | |
from chroma.models.graph_design import GraphDesign, ProteinTraversalSpatial | |
@pytest.fixture
def model():
    """Build the GraphDesign network shared by the tests in this file.

    The network is constructed with ``predict_S_marginals=True`` and
    ``predict_S_potts=True`` (presumably what makes the ``logp_S_marginals``
    and ``logp_S_potts`` outputs available -- confirm against GraphDesign).

    Declared as a pytest fixture because the test functions receive it
    through their ``model`` parameter.

    Returns:
        The GraphDesign instance after ``.eval()`` has been called on it.
    """
    graph_design = GraphDesign(predict_S_marginals=True, predict_S_potts=True)
    # Call eval() before handing the network to any test.
    graph_design.eval()
    return graph_design
def _assert_scores_agree(scores_parallel, scores_sequential, atol=1e-3):
    """Assert that two score dicts carry matching ``logp_S`` and ``logp_chi``.

    Args:
        scores_parallel: dict returned by the single forward pass.
        scores_sequential: score dict returned by ``model.sample``.
        atol: absolute tolerance forwarded to ``torch.allclose``.
    """
    for key in ("logp_S", "logp_chi"):
        assert torch.allclose(
            scores_parallel[key], scores_sequential[key], atol=atol
        ), f"parallel and sequential decoding disagree on {key}"


def test_sequential_decoding(model, XCS):
    """Test that the sequential and parallelized decoding of GNN agree.

    Args:
        model: the GraphDesign fixture.
        XCS: tuple unpacked as ``(X, C, S)``; presumably a fixture supplied
            outside this file -- confirm in conftest.
    """
    X, C, S = XCS

    # Fix a permutation: argsort random noise shaped like C along the last
    # axis, then reuse that one ordering for every call below so both
    # decoding paths are compared under the same traversal order.
    permute_idx = torch.argsort(torch.randn_like(C.float()), dim=-1)

    # Phase 1 (clamped=True): score the given (X, C, S) with one forward pass
    # and compare against the scores returned by sample().
    scores_parallel = model(X, C, S, permute_idx=permute_idx)
    _, _, _, scores_sequential = model.sample(
        X, C, S, permute_idx=permute_idx, clamped=True, return_scores=True
    )
    _assert_scores_agree(scores_parallel, scores_sequential)

    # Phase 2 (clamped=False): take what sample() produced and re-score that
    # output with a forward pass; the two sets of scores must again agree.
    X_sample, S_sample, _, scores_sequential = model.sample(
        X, C, S, permute_idx=permute_idx, clamped=False, return_scores=True
    )
    scores_parallel = model(X_sample, C, S_sample, permute_idx=permute_idx)
    _assert_scores_agree(scores_parallel, scores_sequential)
def test_deterministic_traversal(XCS):
    """Check deterministic flag on ProteinTraversalSpatial module.

    Runs one ``deterministic=True`` traversal twice on the same ``(X, C)``
    and requires both calls to return the same ordering.

    Args:
        XCS: tuple unpacked as ``(X, C, _)``; the third entry is unused here.
    """
    X, C, _ = XCS
    traversal = ProteinTraversalSpatial(deterministic=True)
    first_order = traversal(X, C)
    second_order = traversal(X, C)
    # Determinism means element-for-element identity, so compare exactly
    # instead of within the floating-point tolerance of torch.allclose.
    assert torch.equal(first_order, second_order)
def test_graph_design_outputs(model, XCS):
    """Smoke test the entries of the dict returned by a GraphDesign forward pass.

    Args:
        model: the GraphDesign fixture.
        XCS: tuple unpacked as ``(X, C, S)``.
    """
    X, C, S = XCS
    result = model(X, C, S)
    leading_dims = X.shape[:2]

    # Each sequence log-probability output has exactly X's first two dims.
    for name in ("logp_S", "logp_S_marginals", "logp_S_potts"):
        assert result[name].shape == leading_dims

    # The returned "X_noise" entry must match the input X.
    assert torch.allclose(result["X_noise"], X)

    # The chi outputs add one trailing axis of size 4 to those leading dims.
    for name in ("chi", "logp_chi"):
        shape = result[name].shape
        assert shape[:-1] == leading_dims
        assert shape[-1] == 4