# Scraped from a hosted source viewer: commit ce7bf5b, file size 4,725 bytes.
from pathlib import Path
import numpy as np
import pytest
import torch
import torch.nn.functional as F
import chroma
from chroma.data import xcs
from chroma.data.protein import Protein
from chroma.layers.structure import backbone, conditioners, rmsd, symmetry
from chroma.models.graph_backbone import GraphBackbone
from chroma.models.procap import ProteinCaption
@pytest.fixture(scope="session")
def XCO():
    """Load the bundled ``6wgl.cif`` test structure once per session.

    Returns:
        A tuple ``(X, C, O)``. ``X`` and ``C`` come straight from
        ``Protein.to_XCS()``, and ``X`` is flagged ``requires_grad`` so
        conditioners can differentiate through it. ``O`` is
        ``F.one_hot(S, 20)`` of the sequence tensor returned alongside them.
        Because the fixture is session-scoped, the same tensors are shared
        by every test that requests it.
    """
    # The test resources live next to the installed package: <repo>/tests/resources.
    cif_path = Path(chroma.__file__).parent.parent / "tests" / "resources" / "6wgl.cif"
    X, C, S = Protein.from_CIF(str(cif_path)).to_XCS()
    X.requires_grad = True
    return X, C, F.one_hot(S, 20)
@pytest.fixture(scope="session")
def protein():
    """Build PDB entry ``1drf`` (canonicalized) with two named selections.

    Registers ``"clamp"`` over gti indices 0-14 and ``"semirigid"`` over gti
    indices 15-24 on the protein's system. The conditioner fixtures refer to
    the first one as ``"namesel clamp"``.
    NOTE(review): ``Protein.from_PDBID`` presumably fetches the entry over the
    network; confirm if tests must run offline.
    """
    structure = Protein.from_PDBID("1drf", canonicalize=True)
    # Insertion order is preserved, so "clamp" is saved before "semirigid".
    named_ranges = {"clamp": range(15), "semirigid": range(15, 25)}
    for selection_name, index_range in named_ranges.items():
        structure.sys.save_selection(gti=list(index_range), selname=selection_name)
    return structure
@pytest.fixture
def test_conditioner_pointgroup_conditioner(XCO):
    """Apply a ``SymmetryConditioner`` for point group ``"I"`` to the XCO inputs.

    ``"I"`` is the icosahedral group in Schoenflies notation. The conditioner
    is called once with the literal arguments ``0.0`` and ``0.5``.

    Returns:
        ``(conditioner, X_constrained, C)``, consumed by ``test_sampling``.
    """
    X, C, O = XCO
    icosahedral = symmetry.get_point_group("I")
    sym_conditioner = conditioners.SymmetryConditioner(G=icosahedral, num_chain_neighbors=3)
    # Unpack exactly five outputs; only the constrained coordinates are kept.
    X_sym, _, _, _, _ = sym_conditioner(X, C, O, 0.0, 0.5)
    return sym_conditioner, X_sym, C
@pytest.fixture
def test_conditioner_screw_conditioner(XCO):
    """Apply a ``ScrewConditioner`` (theta=pi/4, tz=5.0, M=10) to the XCO inputs.

    Returns:
        ``(conditioner, X_constrained, C)``, consumed by ``test_sampling``.
    """
    X, C, O = XCO
    screw = conditioners.ScrewConditioner(theta=np.pi / 4, tz=5.0, M=10)
    # Unpack exactly five outputs; only the constrained coordinates are kept.
    X_screw, _, _, _, _ = screw(X, C, O, 0.0, 0.5)
    return screw, X_screw, C
@pytest.fixture
def test_conditioner_Rg_conditioner(XCO):
    """Smoke-check an ``RgConditioner`` forward call on the XCO inputs.

    Returns:
        ``(conditioner, X, C)`` with the *unmodified* input coordinates; the
        conditioner's outputs are not used.
    """
    X, C, O = XCO
    rg_conditioner = conditioners.RgConditioner()
    # Result intentionally discarded: this only checks that the call succeeds.
    rg_conditioner(X, C, O, 0.0, 0.5)
    return rg_conditioner, X, C
@pytest.fixture
def test_conditioner_symmetry_and_substructure(protein):
    """Compose a "clamp" substructure conditioner with a C_3 symmetry conditioner.

    The substructure conditioner comes first in the composition, followed by
    ``SymmetryConditioner`` (C_3, one chain neighbor, ``freeze_com=True``).
    The composition is applied once to ``protein``'s coordinates, flagged
    ``requires_grad``.

    Returns:
        ``(composed_conditioner, X_constrained, C)``, consumed by ``test_sampling``.
    """
    backbone_net = GraphBackbone(dim_nodes=16, dim_edges=16)
    # NOTE(review): the mask returned here is unused; the call only exercises
    # resolving the "clamp" selection. Confirm whether it is still needed.
    protein.get_mask("namesel clamp")
    # List elements are evaluated left to right, so construction order matches
    # building each conditioner separately (substructure first, then symmetry).
    composed = conditioners.ComposedConditioner(
        [
            conditioners.SubstructureConditioner(protein, backbone_net, "namesel clamp"),
            conditioners.SymmetryConditioner(
                G=symmetry.get_point_group("C_3"),
                num_chain_neighbors=1,
                freeze_com=True,
            ),
        ]
    )
    X, C, S = protein.to_XCS()
    X.requires_grad = True
    one_hot_seq = F.one_hot(S, 20)
    X_constrained, _, _, _, _ = composed(X, C, one_hot_seq, 0.0, torch.tensor([0.0]))
    return composed, X_constrained, C
@pytest.fixture
def test_conditioner_substructure_conditioner(protein):
    """Check that a "clamp" ``SubstructureConditioner`` restores the clamped region.

    The conditioner is fed random noise shaped like the native coordinates.
    The backbone RMSD over the "clamp" mask between its output and the native
    coordinates must then be within 0.1 of zero.

    Returns:
        ``(conditioner, X, C)`` with the *native* coordinates.
    """
    aligner = rmsd.BackboneRMSD()
    backbone_net = GraphBackbone(dim_nodes=16, dim_edges=16)
    X, C, S = protein.to_XCS()
    one_hot_seq = F.one_hot(S, 20)
    conditioner = conditioners.SubstructureConditioner(
        protein, backbone_net, "namesel clamp"
    )
    # Random start (unseeded): the clamp should override it on the selection.
    noisy_start = torch.randn_like(X)
    X_conditioned, _, _, _, _ = conditioner(
        noisy_start, C, one_hot_seq, 0.0, torch.tensor([0.0])
    )
    clamp_mask = protein.get_mask("namesel clamp")
    _, clamp_rmsd = aligner.align(X_conditioned, X, clamp_mask)
    assert torch.isclose(clamp_rmsd, torch.tensor(0.0), atol=1e-1)
    return conditioner, X, C
@pytest.fixture
def test_conditioner_procap_conditioner(XCO):
    """Smoke-check a ``ProCapConditioner`` driven by the caption "Test caption".

    A fresh ``ProteinCaption`` model is constructed and the conditioner is
    called once; its outputs are discarded.
    NOTE(review): the meaning of the positional ``-1`` argument is not visible
    here (presumably a chain selector) — confirm against ``ProCapConditioner``.

    Returns:
        ``(conditioner, X, C)`` with the unmodified XCO inputs.
    """
    caption_model = ProteinCaption()
    X, C, O = XCO
    procap = conditioners.ProCapConditioner("Test caption", -1, model=caption_model)
    procap(X, C, O, 0, 0.5)
    return procap, X, C
def collect_conditioners():
    """Return every module global whose name starts with ``test_conditioner_``.

    In this module those globals are the conditioner fixtures, and
    ``test_sampling`` is parametrized over them. The module namespace is read
    at call time, so only names already defined at that point are included,
    in definition order.
    """
    prefix = "test_conditioner_"
    matches = []
    for global_name, value in globals().items():
        if global_name.startswith(prefix):
            matches.append(value)
    return matches
@pytest.fixture(params=["globular"])
def gaussian_noise(request):
    """Build a ``DiffusionChainCov`` noise process for each requested covariance model.

    The fixture is parametrized over covariance-model names; currently only
    ``"globular"``. Complex scaling is disabled and the ``"log_snr"`` noise
    schedule is used.
    """
    # Imported lazily, inside the fixture, exactly as before.
    from chroma.layers.structure.diffusion import DiffusionChainCov

    diffusion_kwargs = {
        "covariance_model": request.param,
        "complex_scaling": False,
        "noise_schedule": "log_snr",
    }
    return DiffusionChainCov(**diffusion_kwargs)
@pytest.mark.parametrize(
    "conditioner",
    collect_conditioners(),
    # Without ids pytest labels these cases "conditioner0", "conditioner1", ...;
    # using the fixture name makes a failure identify the broken conditioner.
    ids=lambda fixture_fn: fixture_fn.__name__,
)
def test_sampling(gaussian_noise, conditioner, request):
    """Smoke-test ``sample_sde`` under each collected conditioner fixture.

    ``conditioner`` is one of the ``test_conditioner_*`` fixture functions. It
    is resolved by name via ``request.getfixturevalue`` to obtain
    ``(conditioner_instance, X_native, C)``. The denoiser stub ignores its
    inputs and always returns ``X_native``, so this only checks that sampling
    with ``N=2`` runs to completion; the samples themselves are not inspected.
    """
    conditioner_instance, X_native, C = request.getfixturevalue(conditioner.__name__)

    def X0_func(X, C, t):
        # Stub denoiser: always "predict" the fixture's structure.
        return X_native

    gaussian_noise.sample_sde(
        X0_func=X0_func, C=C, X_init=None, N=2, conditioner=conditioner_instance
    )
def test_proclass_conditioner(protein):
    """Smoke test for secondary-structure conditioning via ``ProClassConditioner``.

    Builds a ``"secondary_structure"`` ProClass conditioner on CPU for a fixed
    secondary-structure string. It then checks that one forward call on
    ``protein``'s coordinates succeeds; the outputs are not inspected.
    """
    SECONDARY_STRUCTURE = "CCEEEEEEEETTTTECTTTTTTTTCCCHHHHHHHHHHHHCCCTTTTEEEEEECHHHHHHCTGGTTTTTTTEEEEETTTTTTTTTTTCEEECTHHHHHHHHHCHGHGGHCCEEEEEECHHHHHHHHHCTCEEEEEEEEETTCCCTTEECCCCTGGGTEEETETTTTTCCEEEETTEEEEEEEEEEEC"
    X, C, S = protein.to_XCS()
    # ``Tensor.detach`` returns a new tensor and does not modify X in place, so
    # its result must be rebound. That yields a leaf tensor, on which gradients
    # can always be enabled.
    X = X.detach().requires_grad_(True)
    O = F.one_hot(S, 20)
    conditioner = conditioners.ProClassConditioner(
        "secondary_structure", SECONDARY_STRUCTURE, device="cpu"
    )
    conditioner(X, C, O, 0, 0.5)
|