Spaces:
Sleeping
Sleeping
File size: 4,411 Bytes
e9e75df |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
from pathlib import Path
import numpy as np
import pytest
import torch
import torch.nn.functional as F
import chroma
from chroma.data import Protein
from chroma.layers import graph
from chroma.layers.structure import hbonds, protein_graph
@pytest.fixture(scope="session")
def XCS():
    """Load the test structure once per pytest session.

    Returns:
        Tuple ``(X, C, S, pdb_id)``: the three outputs of
        ``Protein(<cif>).to_XCS()`` for ``tests/resources/<pdb_id>.cif``,
        followed by the PDB id itself so tests can refer to the structure
        (e.g. when naming debug output files).
    """
    pdb_id = "6wgl"
    # Repository root: assumes the tests/ directory sits next to the
    # installed ``chroma`` package (i.e. an in-tree checkout) — confirm.
    repo = Path(chroma.__file__).parent.parent
    # Derive the file name from pdb_id so the two cannot drift apart.
    test_cif = str(repo / "tests" / "resources" / f"{pdb_id}.cif")
    X, C, S = Protein(test_cif).to_XCS()
    return X, C, S, pdb_id
def test_backbone_hbonds(XCS, debug_plot=False, debug_matshow=False):
    """Check the total number of backbone H-bonds called on the fixture structure.

    Builds a ``ProteinGraph``, runs ``hbonds.BackboneHBonds`` over its edges,
    scatters the per-edge H-bond values into a dense matrix and compares the
    matrix sum with a manually spot-checked count.

    Args:
        XCS: Session fixture returning ``(X, C, S, pdb_id)``.
        debug_plot: If True, write ``viz_hbonds_<pdb_id>.pml``, a PyMOL script
            that draws PyMOL's own polar-contact calls (``hbonds_pymol``) and
            overlays this module's calls as CGO cylinders/spheres
            (``hbonds_pytorch``).
        debug_matshow: If True, display the dense H-bond matrix for the first
            batch element with matplotlib. Defaulted, so pytest does not treat
            it as a fixture.
    """
    X, C, S, pdb_id = XCS

    bb_hbonds = hbonds.BackboneHBonds()

    # Build the residue graph whose edges are the candidate H-bond pairs.
    graph_builder = protein_graph.ProteinGraph()
    edge_idx, mask_ij = graph_builder(X, C)

    hb, mask_hb, H_i = bb_hbonds(X, C, edge_idx, mask_ij)

    # Scatter the per-edge values (indexed like edge_idx) into a dense matrix.
    hb_dense = graph.scatter_edges(hb[..., None], edge_idx)[..., 0]

    if debug_matshow:
        from matplotlib import pyplot as plt

        plt.matshow(hb_dense[0, :, :].data.numpy())
        plt.show()

    if debug_plot:
        # Write a PyMOL script comparing PyMOL's H-bond calls with ours.
        rgb = (0.3, 0.7, 0.1)
        with open(f"viz_hbonds_{pdb_id}.pml", "w") as f:
            # Preamble: fetch the structure, show backbone sticks, add
            # hydrogens and let PyMOL compute its own polar contacts.
            f.write(
                "delete all\n"
                f"fetch {pdb_id}\n"
                f"hide everything, {pdb_id}\n"
                "show sticks, bb.\n"
                "color white, all\n"
                "color atomic, (not elem C)\n"
                "h_add bb.\n"
                "distance hbonds_pymol, don. and bb., acc. and bb., 3.6, mode=2\n"
                "hide labels\n"
            )
            cgo_list = [protein_graph._cgo_color(rgb)]
            # Visit only edges called as H-bonds (row-major order, same as a
            # nested loop over residues i and neighbor slots j_idx).
            for i, j_idx in torch.nonzero(hb[0] > 0).tolist():
                j = edge_idx[0, i, j_idx]
                # Cylinder from the placed hydrogen on residue i to atom 3 of
                # neighbor residue j (presumably the backbone O — confirm the
                # atom ordering of X), plus a sphere marking the hydrogen.
                cgo_list.append(
                    protein_graph._cgo_cylinder(
                        H_i[0, i, :], X[0, j, 3, :], radius=0.08, rgb=rgb
                    )
                )
                cgo_list.append(
                    protein_graph._cgo_sphere(H_i[0, i, :], radius=0.3)
                )
            cgo_str = " + ".join(cgo_list)
            f.write(f'cmd.load_cgo({cgo_str}, "hbonds_pytorch", 1)\n')

    # These hydrogen bonds were manually spot-checked for 6wgl in PyMOL using
    # the script above. We don't count i -> i+2 bonds, and there appear to be
    # subtle orientation-dependent differences from PyMOL's calls, but the
    # SS-dependent calls agree well.
    assert hb_dense.sum().item() == 303
def test_loss_hbb(XCS, debug=False):
    """Check ``LossBackboneHBonds`` on noised versus native coordinates.

    Noised coordinates must give mean local/nonlocal recovery below 1 and a
    positive mean error. Comparing the native structure with itself must give
    recovery of ~1 and error of ~0.

    Args:
        XCS: Session fixture returning ``(X, C, S, pdb_id)``.
        debug: If True, plot mean recovery against diffusion time ``t`` and
            against the noise schedule's ``alpha(t)`` for
            ``DiffusionChainCov`` noise.
    """
    X, C, S, pdb_id = XCS

    loss_hbb = hbonds.LossBackboneHBonds()

    # Unit Gaussian coordinate noise should destroy some H-bonds.
    torch.manual_seed(1)
    X_noise = X + torch.randn_like(X)
    recovery_local, recovery_nonlocal, error_co = loss_hbb(X_noise, X, C)
    assert recovery_local.mean().item() < 1.0
    assert recovery_nonlocal.mean().item() < 1.0
    # Reduce to a scalar like the other checks; a bare tensor comparison is
    # ambiguous in ``assert`` if error_co has more than one element.
    assert error_co.mean().item() > 0.0

    # The native structure compared with itself: full recovery, ~zero error.
    recovery_local, recovery_nonlocal, error_co = loss_hbb(X, X, C)
    assert recovery_local.mean().item() == pytest.approx(1.0, rel=1e-2)
    assert recovery_nonlocal.mean().item() == pytest.approx(1.0, rel=1e-2)
    # abs= is required: a relative tolerance around an expected 0.0 is zero,
    # so approx(0.0, 1e-2) would effectively demand |error| <= 1e-12.
    assert error_co.mean().item() == pytest.approx(0.0, abs=1e-2)

    if debug:
        # Plot H-bond recovery along a diffusion noising trajectory, both
        # versus time t and versus the schedule's alpha(t).
        from chroma.layers.structure import diffusion

        noise = diffusion.DiffusionChainCov(complex_scaling=True)
        T = np.linspace(0, 1, 100)
        R_local = []
        R_nonlocal = []
        for t in T:
            X_noise = noise(X, C, t=t)
            recovery_local, recovery_nonlocal, error_co = loss_hbb(X_noise, X, C)
            R_local.append(recovery_local.mean().item())
            R_nonlocal.append(recovery_nonlocal.mean().item())
        A = noise.noise_schedule.alpha(T.tolist()).data.numpy().flatten()

        from matplotlib import pyplot as plt

        plt.subplot(1, 2, 1)
        plt.plot(T, R_local, label="Local H-Bonds")
        plt.plot(T, R_nonlocal, label="Nonlocal H-Bonds")
        plt.xlim([0, 1])
        plt.ylim([0, 1])
        plt.xlabel("t")
        plt.ylabel("Recovery")
        plt.legend()
        plt.grid()
        plt.subplot(1, 2, 2)
        plt.plot(A, R_local, label="Local H-Bonds")
        plt.plot(A, R_nonlocal, label="Nonlocal H-Bonds")
        plt.xlim([0, 1])
        plt.ylim([0, 1])
        plt.xlabel("alpha")
        plt.ylabel("Recovery")
        plt.legend()
        plt.grid()
        plt.show()
|