File size: 4,725 Bytes
ce7bf5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
from pathlib import Path

import numpy as np
import pytest
import torch
import torch.nn.functional as F

import chroma
from chroma.data import xcs
from chroma.data.protein import Protein
from chroma.layers.structure import backbone, conditioners, rmsd, symmetry
from chroma.models.graph_backbone import GraphBackbone
from chroma.models.procap import ProteinCaption


@pytest.fixture(scope="session")
def XCO():
    """Session-scoped fixture built from the bundled ``6wgl.cif`` test resource.

    Returns:
        A tuple ``(X, C, O)`` where ``X`` and ``C`` are the first two outputs of
        ``Protein.to_XCS()`` (``X`` flagged with ``requires_grad = True``) and
        ``O`` is the 20-class one-hot encoding of the third output ``S``.
    """
    # The resource lives in <repo>/tests/resources, two levels above chroma/__init__.py.
    resources_dir = Path(chroma.__file__).parent.parent / "tests" / "resources"
    cif_path = str(resources_dir / "6wgl.cif")

    X, C, S = Protein.from_CIF(cif_path).to_XCS()
    # Consumers of this fixture get an X that participates in autograd.
    X.requires_grad = True
    return X, C, F.one_hot(S, 20)


@pytest.fixture(scope="session")
def protein():
    """Session-scoped fixture: PDB entry ``1drf`` with two saved selections.

    The selections are stored on ``protein.sys`` under the names ``"clamp"``
    (indices 0-14) and ``"semirigid"`` (indices 15-24) so that other fixtures
    can refer to them by name.

    Returns:
        The loaded ``Protein`` object.
    """
    loaded = Protein.from_PDBID("1drf", canonicalize=True)

    # (selection name, index range) pairs, registered in this order.
    named_ranges = (
        ("clamp", range(15)),
        ("semirigid", range(15, 25)),
    )
    for selection_name, index_range in named_ranges:
        loaded.sys.save_selection(gti=list(index_range), selname=selection_name)

    return loaded


@pytest.fixture
def test_conditioner_pointgroup_conditioner(XCO):
    """Fixture providing a ``SymmetryConditioner`` for point group ``"I"``.

    The ``test_conditioner_`` prefix is what ``collect_conditioners`` uses to
    discover this fixture.

    Returns:
        ``(conditioner, X_constrained, C)`` where ``X_constrained`` is the first
        element of the conditioner's five-element output.
    """
    X, C, O = XCO
    point_group = symmetry.get_point_group("I")
    conditioner = conditioners.SymmetryConditioner(
        G=point_group,
        num_chain_neighbors=3,
    )
    # NOTE(review): the trailing 0.0 and 0.5 are passed positionally; their
    # meaning is defined by the conditioner's call signature - confirm there.
    outputs = conditioner(X, C, O, 0.0, 0.5)
    X_constrained, _, _, _, _ = outputs
    return conditioner, X_constrained, C


@pytest.fixture
def test_conditioner_screw_conditioner(XCO):
    """Fixture providing a ``ScrewConditioner`` (theta=pi/4, tz=5.0, M=10).

    Returns:
        ``(conditioner, X_constrained, C)`` where ``X_constrained`` is the first
        element of the conditioner's five-element output.
    """
    X, C, O = XCO
    quarter_pi = np.pi / 4
    conditioner = conditioners.ScrewConditioner(theta=quarter_pi, tz=5.0, M=10)
    outputs = conditioner(X, C, O, 0.0, 0.5)
    X_constrained, _, _, _, _ = outputs
    return conditioner, X_constrained, C


@pytest.fixture
def test_conditioner_Rg_conditioner(XCO):
    """Fixture providing a default-constructed ``RgConditioner``.

    The conditioner is invoked once as a smoke check; its output is discarded.

    Returns:
        ``(conditioner, X, C)`` with ``X`` and ``C`` taken unchanged from ``XCO``.
    """
    X, C, O = XCO
    rg_conditioner = conditioners.RgConditioner()
    # Result intentionally ignored: only checking that the call succeeds.
    rg_conditioner(X, C, O, 0.0, 0.5)
    return rg_conditioner, X, C


@pytest.fixture
def test_conditioner_symmetry_and_substructure(protein):
    """Fixture composing a substructure conditioner with a C_3 symmetry conditioner.

    Returns:
        ``(composed_conditioner, X_constrained, C)`` where ``X_constrained`` is
        the first element of the composed conditioner's five-element output.
    """
    clamp_selection = "namesel clamp"

    backbone_network = GraphBackbone(dim_nodes=16, dim_edges=16)
    # Return value unused; presumably checks that the selection resolves - TODO confirm.
    protein.get_mask(clamp_selection)
    substructure = conditioners.SubstructureConditioner(
        protein, backbone_network, clamp_selection
    )
    c3_symmetry = conditioners.SymmetryConditioner(
        G=symmetry.get_point_group("C_3"),
        num_chain_neighbors=1,
        freeze_com=True,
    )
    # Substructure is listed first, symmetry second.
    composed_conditioner = conditioners.ComposedConditioner([substructure, c3_symmetry])

    X, C, S = protein.to_XCS()
    X.requires_grad = True
    outputs = composed_conditioner(
        X, C, F.one_hot(S, 20), 0.0, torch.tensor([0.0])
    )
    X_constrained, _, _, _, _ = outputs
    return composed_conditioner, X_constrained, C


@pytest.fixture
def test_conditioner_substructure_conditioner(protein):
    """Fixture providing a ``SubstructureConditioner`` on the ``clamp`` selection.

    Before returning, the conditioner is applied to random coordinates and the
    result is aligned to the native ``X`` using the ``clamp`` mask; the fixture
    asserts that the resulting RMSD is within 0.1 of zero.

    Returns:
        ``(conditioner, X, C)`` with ``X`` and ``C`` from ``protein.to_XCS()``.
    """
    clamp_selection = "namesel clamp"

    backbone_aligner = rmsd.BackboneRMSD()
    backbone_network = GraphBackbone(dim_nodes=16, dim_edges=16)
    X, C, S = protein.to_XCS()
    O = F.one_hot(S, 20)
    conditioner = conditioners.SubstructureConditioner(
        protein, backbone_network, clamp_selection
    )

    # Feed Gaussian noise shaped like X through the conditioner.
    random_coords = torch.randn_like(X)
    outputs = conditioner(random_coords, C, O, 0.0, torch.tensor([0.0]))
    X_conditioned, _, _, _, _ = outputs

    clamp_mask = protein.get_mask(clamp_selection)
    _, clamp_rmsd = backbone_aligner.align(X_conditioned, X, clamp_mask)
    assert torch.isclose(clamp_rmsd, torch.tensor(0.0), atol=1e-1)

    return conditioner, X, C


@pytest.fixture
def test_conditioner_procap_conditioner(XCO):
    """Fixture providing a ``ProCapConditioner`` backed by a default ``ProteinCaption``.

    The conditioner is invoked once as a smoke check; its output is discarded.

    Returns:
        ``(conditioner, X, C)`` with ``X`` and ``C`` taken unchanged from ``XCO``.
    """
    caption_model = ProteinCaption()
    X, C, O = XCO
    # NOTE(review): the second positional argument (-1) is defined by
    # ProCapConditioner's signature - confirm its meaning there.
    procap = conditioners.ProCapConditioner("Test caption", -1, model=caption_model)
    procap(X, C, O, 0, 0.5)
    return procap, X, C


def collect_conditioners():
    """Return every module-level object whose name starts with ``test_conditioner_``.

    The objects are returned in the iteration order of the module's global
    namespace.
    """
    prefix = "test_conditioner_"
    matches = []
    for name, obj in globals().items():
        if name.startswith(prefix):
            matches.append(obj)
    return matches


@pytest.fixture(params=["globular"])
def gaussian_noise(request):
    """Parametrized fixture building a ``DiffusionChainCov`` noise process.

    The covariance model comes from the fixture parameter (currently only
    ``"globular"``); complex scaling is disabled and the noise schedule is
    ``"log_snr"``.
    """
    # Imported locally, as in the rest of this fixture's history, rather than at module top.
    from chroma.layers.structure.diffusion import DiffusionChainCov

    diffusion_options = {
        "covariance_model": request.param,
        "complex_scaling": False,
        "noise_schedule": "log_snr",
    }
    return DiffusionChainCov(**diffusion_options)


@pytest.mark.parametrize("conditioner", collect_conditioners())
def test_sampling(gaussian_noise, conditioner, request):
    """Smoke test: run ``sample_sde`` for ``N=2`` with each collected conditioner.

    ``conditioner`` is a fixture object gathered by ``collect_conditioners``;
    its value is resolved by name through ``request.getfixturevalue``.
    """
    fixture_name = conditioner.__name__
    active_conditioner, X_native, C = request.getfixturevalue(fixture_name)

    def return_native(X, C, t):
        # Ignores its inputs and always hands back the fixture's coordinates.
        return X_native

    # No assertion on the result: the test passes if sampling completes.
    gaussian_noise.sample_sde(
        X0_func=return_native,
        C=C,
        X_init=None,
        N=2,
        conditioner=active_conditioner,
    )


def test_proclass_conditioner(protein):
    """Smoke test for secondary structure conditioning.

    Builds a ``ProClassConditioner`` for the ``"secondary_structure"`` label on
    CPU and applies it once to the fixture protein; the test passes if the call
    completes without raising.
    """
    SECONDARY_STRUCTURE = "CCEEEEEEEETTTTECTTTTTTTTCCCHHHHHHHHHHHHCCCTTTTEEEEEECHHHHHHCTGGTTTTTTTEEEEETTTTTTTTTTTCEEECTHHHHHHHHHCHGHGGHCCEEEEEECHHHHHHHHHCTCEEEEEEEEETTCCCTTEECCCCTGGGTEEETETTTTTCCEEEETTEEEEEEEEEEEC"
    X, C, S = protein.to_XCS()
    # Tensor.detach() is not in-place: rebind the result so the gradient flag
    # below is set on a detached leaf tensor rather than on the original.
    X = X.detach()
    X.requires_grad = True
    O = F.one_hot(S, 20)
    conditioner = conditioners.ProClassConditioner(
        "secondary_structure", SECONDARY_STRUCTURE, device="cpu"
    )
    # Output intentionally ignored: only checking that the call succeeds.
    conditioner(X, C, O, 0, 0.5)