File size: 4,411 Bytes
ce7bf5b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
from pathlib import Path

import numpy as np
import pytest
import torch
import torch.nn.functional as F

import chroma
from chroma.data import Protein
from chroma.layers import graph
from chroma.layers.structure import hbonds, protein_graph


@pytest.fixture(scope="session")
def XCS():
    """Load the bundled 6WGL test structure once per test session.

    The CIF file is read from ``tests/resources`` next to the installed
    ``chroma`` package and converted with ``Protein(...).to_XCS()``.

    Returns:
        Tuple ``(X, C, S, pdb_id)``: the three outputs of ``to_XCS()``
        followed by the structure identifier string ``"6wgl"``.
    """
    resources_dir = Path(chroma.__file__).parent.parent / "tests" / "resources"
    pdb_id = "6wgl"
    cif_path = resources_dir / "6wgl.cif"
    X, C, S = Protein(str(cif_path)).to_XCS()
    return X, C, S, pdb_id


def _write_pymol_hbonds_script(pdb_id, X, edge_idx, hb, H_i, rgb=(0.3, 0.7, 0.1)):
    """Write ``viz_hbonds_{pdb_id}.pml`` to compare PyMOL and PyTorch H-bonds.

    The script fetches ``pdb_id``, shows backbone sticks, and asks PyMOL to
    compute its own backbone donor/acceptor distances (object
    ``hbonds_pymol``, cutoff 3.6). It then loads a CGO object named
    ``hbonds_pytorch``. For every edge ``(i, j_idx)`` of batch 0 where
    ``hb > 0``, the object holds a cylinder from ``H_i[0, i]`` to atom index
    3 of neighbour ``j = edge_idx[0, i, j_idx]`` and a sphere at
    ``H_i[0, i]``.

    Args:
        pdb_id: Structure identifier used for ``fetch`` and the file name.
        X: Coordinates indexed as ``X[batch, residue, atom, xyz]``. Atom 3 is
            presumably the backbone carbonyl O; confirm the atom ordering.
        edge_idx: Neighbour indices indexed as ``edge_idx[batch, i, j_idx]``.
        hb: Per-edge H-bond scores with the same leading layout as
            ``edge_idx``.
        H_i: Hydrogen positions indexed as ``H_i[batch, i, xyz]``.
        rgb: Colour of the PyTorch H-bond cylinders.
    """
    header = (
        "delete all\n"
        f"fetch {pdb_id}\n"
        f"hide everything, {pdb_id}\n"
        "show sticks, bb.\n"
        "color white, all\n"
        "color atomic, (not elem C)\n"
        "h_add bb.\n"
        "distance hbonds_pymol, don. and bb., acc. and bb., 3.6, mode=2\n"
        "hide labels\n"
    )

    # Build the CGO primitives before opening the file so an error here does
    # not leave a half-written script behind.
    cgo_list = [protein_graph._cgo_color(rgb)]
    # nonzero() yields (i, j_idx) pairs in the same row-major order as a
    # nested loop over nodes and then neighbours.
    for i, j_idx in (hb[0] > 0).nonzero().tolist():
        j = edge_idx[0, i, j_idx]
        cgo_list.append(
            protein_graph._cgo_cylinder(
                H_i[0, i, :], X[0, j, 3, :], radius=0.08, rgb=rgb
            )
        )
        cgo_list.append(protein_graph._cgo_sphere(H_i[0, i, :], radius=0.3))
    cgo_str = " + ".join(cgo_list)

    with open(f"viz_hbonds_{pdb_id}.pml", "w", encoding="utf-8") as f:
        f.write(header)
        f.write(f'cmd.load_cgo({cgo_str}, "hbonds_pytorch", 1)\n')


def test_backbone_hbonds(XCS, debug_plot=False):
    """Check backbone H-bond calls on 6WGL against a spot-checked total.

    Args:
        XCS: Session fixture yielding ``(X, C, S, pdb_id)``.
        debug_plot: If True, also write ``viz_hbonds_{pdb_id}.pml`` so the
            calls can be inspected against PyMOL's own H-bond detection.
    """
    X, C, _, pdb_id = XCS

    bb_hbonds = hbonds.BackboneHBonds()

    # Build the residue graph, then score backbone H-bonds on its edges.
    graph_builder = protein_graph.ProteinGraph()
    edge_idx, mask_ij = graph_builder(X, C)
    hb, _, H_i = bb_hbonds(X, C, edge_idx, mask_ij)
    # Scatter the per-edge scores into a dense layout so they can be summed.
    hb_dense = graph.scatter_edges(hb[..., None], edge_idx)[..., 0]

    if debug_plot:
        _write_pymol_hbonds_script(pdb_id, X, edge_idx, hb, H_i)

    # These hydrogen bonds were manually spot-checked for 6wgl in PyMOL
    # using the script from _write_pymol_hbonds_script. i -> i+2 pairs are
    # not counted. There appear to be subtle orientation-dependent
    # differences, but the secondary-structure-dependent calls agree well.
    assert hb_dense.sum().item() == 303


def _plot_hbond_recovery_vs_noise(X, C, loss_hbb, num_steps=100):
    """Plot H-bond recovery of diffused structures against ``t`` and ``alpha``.

    Noise is applied with ``DiffusionChainCov(complex_scaling=True)`` at
    ``num_steps`` evenly spaced times in [0, 1]. At each time, the mean local
    and non-local recovery from ``loss_hbb`` is recorded. The curves are
    shown against ``t`` and against the noise schedule's ``alpha(t)``.
    Requires matplotlib and an interactive display.
    """
    from chroma.layers.structure import diffusion

    noise = diffusion.DiffusionChainCov(complex_scaling=True)

    T = np.linspace(0, 1, num_steps)
    R_local = []
    R_nonlocal = []
    for t in T:
        X_noise = noise(X, C, t=t)
        recovery_local, recovery_nonlocal, _ = loss_hbb(X_noise, X, C)
        R_local.append(recovery_local.mean().item())
        R_nonlocal.append(recovery_nonlocal.mean().item())
    A = noise.noise_schedule.alpha(T.tolist()).data.numpy().flatten()

    from matplotlib import pyplot as plt

    # Left panel: recovery vs. diffusion time; right panel: vs. alpha(t).
    for panel, (x_values, x_label) in enumerate(((T, "t"), (A, "alpha")), start=1):
        plt.subplot(1, 2, panel)
        plt.plot(x_values, R_local, label="Local H-Bonds")
        plt.plot(x_values, R_nonlocal, label="Nonlocal H-Bonds")
        plt.xlim([0, 1])
        plt.ylim([0, 1])
        plt.xlabel(x_label)
        plt.ylabel("Recovery")
        plt.legend()
        plt.grid()
    plt.show()


def test_loss_hbb(XCS, debug=False):
    """Check LossBackboneHBonds on noised vs. identical structures.

    Adding unit Gaussian noise should push mean local and non-local recovery
    below 1 and make the error term positive. Comparing a structure with
    itself should give recovery ~1 and error ~0.

    Args:
        XCS: Session fixture yielding ``(X, C, S, pdb_id)``.
        debug: If True, also plot recovery across a diffusion noise schedule.
    """
    X, C, _, _ = XCS
    loss_hbb = hbonds.LossBackboneHBonds()

    torch.manual_seed(1)
    X_noise = X + torch.randn_like(X)
    recovery_local, recovery_nonlocal, error_co = loss_hbb(X_noise, X, C)
    assert recovery_local.mean().item() < 1.0
    assert recovery_nonlocal.mean().item() < 1.0
    # Reduce to a Python float, as below, so the check does not raise
    # "ambiguous truth value" if error_co has more than one element.
    assert error_co.mean().item() > 0.0

    recovery_local, recovery_nonlocal, error_co = loss_hbb(X, X, C)
    # Use an absolute tolerance. A relative tolerance alone is meaningless
    # around an expected value of 0.0: pytest would fall back to 1e-12.
    assert recovery_local.mean().item() == pytest.approx(1.0, abs=1e-2)
    assert recovery_nonlocal.mean().item() == pytest.approx(1.0, abs=1e-2)
    assert error_co.mean().item() == pytest.approx(0.0, abs=1e-2)

    if debug:
        _plot_hbond_recovery_vs_noise(X, C, loss_hbb)