Spaces:
Sleeping
Sleeping
from pathlib import Path | |
import numpy as np | |
import pytest | |
import torch | |
import torch.nn.functional as F | |
import chroma | |
from chroma.data import Protein | |
from chroma.layers import graph | |
from chroma.layers.structure import hbonds, protein_graph | |
def XCS(): | |
repo = Path(chroma.__file__).parent.parent | |
pdb_id = "6wgl" | |
test_cif = str(Path(repo, "tests", "resources", "6wgl.cif")) | |
X, C, S = Protein(test_cif).to_XCS() | |
return X, C, S, pdb_id | |
def test_backbone_hbonds(XCS, debug_plot=False): | |
X, C, S, pdb_id = XCS | |
bb_hbonds = hbonds.BackboneHBonds() | |
# Build Graph | |
graph_builder = protein_graph.ProteinGraph() | |
edge_idx, mask_ij = graph_builder(X, C) | |
hb, mask_hb, H_i = bb_hbonds(X, C, edge_idx, mask_ij) | |
hb_dense = graph.scatter_edges(hb[..., None], edge_idx)[..., 0] | |
if debug_plot: | |
if False: | |
H = hb_dense[0, :, :].data.numpy() | |
from matplotlib import pyplot as plt | |
plt.matshow(H) | |
plt.show() | |
# Build | |
rgb = (0.3, 0.7, 0.1) | |
with open(f"viz_hbonds_{pdb_id}.pml", "w") as f: | |
f.write( | |
"delete all\n" | |
f"fetch {pdb_id}\n" | |
f"hide everything, {pdb_id}\n" | |
"show sticks, bb.\n" | |
"color white, all\n" | |
"color atomic, (not elem C)\n" | |
"h_add bb.\n" | |
"distance hbonds_pymol, don. and bb., acc. and bb., 3.6, mode=2\n" | |
"hide labels\n" | |
) | |
cgo_list = [protein_graph._cgo_color(rgb)] | |
for i in range(edge_idx.size(1)): | |
for j_idx in range(edge_idx.size(2)): | |
if hb[0, i, j_idx] > 0: | |
j = edge_idx[0, i, j_idx] | |
cgo_list.append( | |
protein_graph._cgo_cylinder( | |
H_i[0, i, :], X[0, j, 3, :], radius=0.08, rgb=rgb | |
) | |
) | |
cgo_list.append( | |
protein_graph._cgo_sphere(H_i[0, i, :], radius=0.3) | |
) | |
cgo_str = " + ".join(cgo_list) | |
f.write(f'cmd.load_cgo({cgo_str}, "hbonds_pytorch", 1)\n') | |
# These hydrogen bonds were manually spot checked for 6wgl | |
# in Pymol using the above script. We don't count i-i+2 and | |
# there appear to be subtle orientation dependent, but | |
# SS-dependent calls agree well | |
assert hb_dense.sum().item() == 303 | |
def test_loss_hbb(XCS, debug=False): | |
X, C, S, pdb_id = XCS | |
loss_hbb = hbonds.LossBackboneHBonds() | |
torch.manual_seed(1.0) | |
X_noise = X + torch.randn_like(X) | |
recovery_local, recovery_nonlocal, error_co = loss_hbb(X_noise, X, C) | |
assert recovery_local.mean().item() < 1.0 | |
assert recovery_nonlocal.mean().item() < 1.0 | |
assert error_co > 0.0 | |
recovery_local, recovery_nonlocal, error_co = loss_hbb(X, X, C) | |
assert recovery_local.mean().item() == pytest.approx(1.0, 1e-2) | |
assert recovery_nonlocal.mean().item() == pytest.approx(1.0, 1e-2) | |
assert error_co.mean().item() == pytest.approx(0.0, 1e-2) | |
if debug: | |
# This | |
from chroma.layers.structure import diffusion | |
noise = diffusion.DiffusionChainCov(complex_scaling=True) | |
T = np.linspace(0, 1, 100) | |
R_local = [] | |
R_nonlocal = [] | |
for t in T: | |
X_noise = noise(X, C, t=t) | |
recovery_local, recovery_nonlocal, error_co = loss_hbb(X_noise, X, C) | |
R_local.append(recovery_local.mean().item()) | |
R_nonlocal.append(recovery_nonlocal.mean().item()) | |
A = noise.noise_schedule.alpha(T.tolist()).data.numpy().flatten() | |
from matplotlib import pyplot as plt | |
plt.subplot(1, 2, 1) | |
plt.plot(T, R_local, label="Local H-Bonds") | |
plt.plot(T, R_nonlocal, label="Nonlocal H-Bonds") | |
plt.xlim([0, 1]) | |
plt.ylim([0, 1]) | |
plt.xlabel("t") | |
plt.ylabel("Recovery") | |
plt.legend() | |
plt.grid() | |
plt.subplot(1, 2, 2) | |
plt.plot(A, R_local, label="Local H-Bonds") | |
plt.plot(A, R_nonlocal, label="Nonlocal H-Bonds") | |
plt.xlim([0, 1]) | |
plt.ylim([0, 1]) | |
plt.xlabel("alpha") | |
plt.ylabel("Recovery") | |
plt.legend() | |
plt.grid() | |
plt.show() | |
return | |