Hukuna's picture
Upload 221 files
ce7bf5b verified
raw
history blame
8.92 kB
import numpy as np
import pytest
import torch
from chroma.data import Protein
from chroma.layers.structure.backbone import RigidTransformer
from chroma.layers.structure.rmsd import (
BackboneRMSD,
CrossRMSD,
LossFragmentPairRMSD,
LossFragmentRMSD,
LossNeighborhoodRMSD,
)
@pytest.fixture
def backbones():
    """Provide two hard-coded coordinate sets, each of shape (1, 12, 3).

    Each set is written as twelve xyz rows and reshaped to a single-batch
    tensor. The RMSD tests below compare the two sets against each other.
    """
    rows_first = [
        [-5.68175, -2.183, 3.27979],
        [-4.82875, -3.256, 2.79379],
        [-3.34475, -2.899, 2.79579],
        [-2.51375, -3.697, 3.21979],
        [-3.01675, -1.713, 2.29979],
        [-1.62875, -1.289, 2.22979],
        [-0.95775, -1.094, 3.58379],
        [0.20325, -1.46, 3.75679],
        [-1.69375, -0.547, 4.54479],
        [-1.16675, -0.358, 5.88579],
        [-0.94175, -1.732, 6.50679],
        [0.03125, -1.943, 7.23679],
    ]
    rows_second = [
        [3.91725, 1.271, -1.22921],
        [3.22825, 0.099, -1.74321],
        [2.09025, 0.535, -2.66521],
        [1.91025, -0.018, -3.74821],
        [1.34825, 1.553, -2.23921],
        [0.24325, 2.085, -3.02321],
        [0.76425, 2.518, -4.38621],
        [0.10225, 2.315, -5.41221],
        [1.96925, 3.085, -4.38121],
        [2.61525, 3.562, -5.59821],
        [3.27725, 2.453, -6.40421],
        [4.07425, 2.713, -7.30321],
    ]
    # Leading -1 keeps the batch axis explicit: (12, 3) -> (1, 12, 3).
    first = torch.tensor(rows_first).reshape(-1, 12, 3)
    second = torch.tensor(rows_second).reshape(-1, 12, 3)
    return first, second
def test_pairedRMSD(backbones):
    """Default ``CrossRMSD.pairedRMSD`` on the fixture pair must be ~0.3542."""
    coords_a, coords_b = backbones
    expected = torch.tensor(0.3542)
    observed = CrossRMSD().pairedRMSD(coords_a, coords_b)
    assert torch.isclose(observed, expected, rtol=1e-3)
def test_pairedRMSD_symeig(backbones):
    """Paired RMSD of the fixture pair using the ``"symeig"`` method."""
    coords_a, coords_b = backbones
    expected = torch.tensor(0.35)
    observed = CrossRMSD(method="symeig").pairedRMSD(coords_a, coords_b)
    # NOTE(review): rtol=1e1 accepts anything within ~3.5 of 0.35, so this
    # assertion is very loose. It may be a typo for 1e-1 -- confirm the
    # precision of the symeig path before tightening.
    assert torch.isclose(observed, expected, rtol=1e1)
def test_sample(backbones):
    """Exercise the ``CrossRMSD`` call interface with the default method.

    Two checks are made:

    * Calling the module with the same batch on both sides returns a square
      matrix (batch x batch) whose entries are all ~0 (atol 0.1).
    * Calling it on the two distinct fixture backbones returns values close
      to 0.35 (rtol 0.1).
    """
    bb1, bb2 = backbones
    cross_rmsd = CrossRMSD()

    # Duplicate the first backbone so the batch holds two identical entries.
    input_x = torch.cat([bb1, bb1])
    predicted = cross_rmsd(input_x, input_x)
    assert predicted.shape == (input_x.shape[0], input_x.shape[0])
    assert torch.allclose(predicted, torch.zeros_like(predicted), atol=1e-1)

    predicted = cross_rmsd(bb1, bb2)
    # Use the tensor reduction rather than builtin all(): the builtin walks
    # rows and calls bool() on each, which only works for single-element rows.
    assert torch.isclose(predicted, torch.tensor(0.35), rtol=1e-1).all()
def test_sample_symeigh(backbones):
    """Self-comparison through ``CrossRMSD(method="symeig")`` must be ~0 (atol 0.01)."""
    reference, _ = backbones
    scorer = CrossRMSD(method="symeig")
    # Batch of two identical backbones, compared against itself.
    stacked = torch.cat([reference, reference])
    pairwise = scorer(stacked, stacked)
    assert torch.allclose(pairwise, torch.zeros_like(pairwise), atol=1e-2)
def test_backbone_rmsd(backbones):
    """``BackboneRMSD.align`` must undo a rigid transform for both methods.

    The structure "5imm" is loaded once, moved with ``RigidTransformer``, and
    then re-aligned onto the original with each of the ``"symeig"`` and
    ``"power"`` methods. The aligned coordinates must match the original
    within 1e-2 and the reported RMSD must be below 1e-2.

    The ``backbones`` fixture is requested for signature compatibility but is
    not used by this test.
    """
    # The reference structure and its transformed copy do not depend on the
    # alignment method, so build them once instead of once per iteration.
    X, C, _ = Protein("5imm").to_XCS()
    dX = torch.Tensor([[1, 4, 2]])
    # presumably an (unnormalised) rotation quaternion -- confirm against
    # RigidTransformer's signature.
    q = torch.Tensor([[0.5, 1, 0, 1]])
    X_transform = RigidTransformer()(X, dX, q)

    # Sanity check: the transform must actually move the structure.
    assert not torch.allclose(X, X_transform, atol=1e-2)

    for method in ["symeig", "power"]:
        backbone_rmsd = BackboneRMSD(method=method)
        X_transform_aligned, rmsd = backbone_rmsd.align(
            X_transform, X, C, align_unmasked=True
        )
        assert torch.allclose(X, X_transform_aligned, atol=1e-2), method
        assert rmsd < 1e-2, method
def test_fragment_rmsd(debug=False):
    """Check ``LossFragmentRMSD`` on structure "1SHG".

    The mean loss must be below 1e-2 when a structure is compared with itself
    and above 1.0 when standard-normal noise is added to the coordinates.

    Args:
        debug: When true, additionally diffuse the structure with
            ``DiffusionChainCov`` at ``t=0.6``, print the resulting loss, and
            write the target / aligned / unaligned fragments as trajectories
            to ``X_frag_target.cif``, ``X_frag_noise_aligned.cif`` and
            ``X_frag_noise.cif`` in the working directory.
    """
    X, C, _ = Protein("1SHG").to_XCS()
    loss_frags = LossFragmentRMSD()

    X_perturbed = X + torch.randn_like(X)
    rmsd_identity = loss_frags(X, X, C)
    rmsd_perturbed = loss_frags(X_perturbed, X, C)
    assert rmsd_identity.mean() < 1e-2
    assert rmsd_perturbed.mean() > 1.0

    if not debug:
        return

    from chroma.layers.structure import diffusion

    diffuse = diffusion.DiffusionChainCov(complex_scaling=True)
    X_diffused = diffuse(X, C, t=0.6)
    rmsd_debug, frag_target, frag_mobile, frag_aligned = loss_frags(
        X_diffused, X, C, return_coords=True
    )
    print(rmsd_debug)

    def _as_frames(X_frags):
        # Merge the two leading axes, regroup the rest into 4 atoms x 3
        # coordinates, and emit one single-batch tensor per fragment.
        batch, count, _, _ = X_frags.shape
        stacked = X_frags.reshape([batch * count, -1, 4, 3])
        return [frame[None, ...] for frame in stacked.unbind(0)]

    C_frag = torch.ones([1, loss_frags.k])
    target_frames = _as_frames(frag_target)
    aligned_frames = _as_frames(frag_aligned)
    mobile_frames = _as_frames(frag_mobile)

    # Every trajectory starts with the same target frame (the original note
    # here reads "Fight pymol confusion").
    anchor = target_frames[10]
    outputs = (
        ("X_frag_target.cif", target_frames),
        ("X_frag_noise_aligned.cif", aligned_frames),
        ("X_frag_noise.cif", mobile_frames),
    )
    for path, frames in outputs:
        trajectory = [anchor] + frames
        Protein.from_XCS_trajectory(trajectory, C_frag, 0.0 * C_frag).to_CIF(path)
def test_fragment_pair_rmsd(debug=False):
    """Check ``LossFragmentPairRMSD`` on structure "1SHG".

    The mean loss must be below 1e-2 when a structure is compared with itself
    and above 1.0 when standard-normal noise is added to the coordinates.

    Args:
        debug: When true, additionally diffuse the structure with
            ``DiffusionChainCov`` at ``t=0.6``, print the resulting loss, and
            write the target / aligned / unaligned fragment pairs as
            trajectories to ``X_pair_target.cif``,
            ``X_pair_noise_aligned.cif`` and ``X_pair_noise.cif`` in the
            working directory.
    """
    X, C, _ = Protein("1SHG").to_XCS()
    loss_pairs = LossFragmentPairRMSD()

    X_perturbed = X + torch.randn_like(X)
    rmsd_identity, _ = loss_pairs(X, X, C)
    rmsd_perturbed, _ = loss_pairs(X_perturbed, X, C)
    assert rmsd_identity.mean() < 1e-2
    assert rmsd_perturbed.mean() > 1.0

    if not debug:
        return

    from chroma.layers.structure import diffusion

    diffuse = diffusion.DiffusionChainCov(complex_scaling=True)
    X_diffused = diffuse(X, C, t=0.6)
    rmsd_debug, _, pair_target, pair_mobile, pair_aligned = loss_pairs(
        X_diffused, X, C, return_coords=True
    )
    print(rmsd_debug)

    def _as_frames(X_pairs):
        # Merge the three leading axes, regroup the rest into 4 atoms x 3
        # coordinates, and emit one single-batch tensor per pair.
        batch, rows, cols, _, _ = X_pairs.shape
        stacked = X_pairs.reshape([batch * rows * cols, -1, 4, 3])
        return [frame[None, ...] for frame in stacked.unbind(0)]

    # The first k positions are labelled 1 and the next k are labelled 2.
    span = loss_pairs.k
    C_pair = torch.cat([torch.ones([1, span]), 2 * torch.ones([1, span])], -1)

    target_frames = _as_frames(pair_target)
    aligned_frames = _as_frames(pair_aligned)
    mobile_frames = _as_frames(pair_mobile)

    # Every trajectory starts with the same target frame (the original note
    # here reads "Fight pymol confusion").
    anchor = target_frames[1579]
    outputs = (
        ("X_pair_target.cif", target_frames),
        ("X_pair_noise_aligned.cif", aligned_frames),
        ("X_pair_noise.cif", mobile_frames),
    )
    for path, frames in outputs:
        trajectory = [anchor] + frames
        Protein.from_XCS_trajectory(trajectory, C_pair, 0.0 * C_pair).to_CIF(path)
def test_neighborhood_rmsd(debug=False):
    """Check ``LossNeighborhoodRMSD`` on structure "1SHG".

    The mean loss must be below 1e-2 when a structure is compared with itself
    and above 1.0 when standard-normal noise is added to the coordinates.

    Args:
        debug: When true, additionally diffuse the structure with
            ``DiffusionChainCov`` at ``t=0.7``, print the resulting loss and
            the target-neighborhood shape, and write the target / aligned /
            unaligned neighborhoods as trajectories to ``X_nb_target.cif``,
            ``X_nb_noise_aligned.cif`` and ``X_nb_noise.cif`` in the working
            directory.
    """
    X, C, _ = Protein("1SHG").to_XCS()
    loss_nb = LossNeighborhoodRMSD()

    X_perturbed = X + torch.randn_like(X)
    rmsd_identity, _ = loss_nb(X, X, C)
    rmsd_perturbed, _ = loss_nb(X_perturbed, X, C)
    assert rmsd_identity.mean() < 1e-2
    assert rmsd_perturbed.mean() > 1.0

    if not debug:
        return

    from chroma.layers.structure import diffusion

    diffuse = diffusion.DiffusionChainCov(complex_scaling=True)
    X_diffused = diffuse(X, C, t=0.7)
    rmsd_debug, _, nb_target, nb_mobile, nb_aligned = loss_nb(
        X_diffused, X, C, return_coords=True
    )
    print(rmsd_debug, nb_target.shape)

    def _as_frames(X_nbs):
        # Merge the two leading axes, regroup the rest into 4 atoms x 3
        # coordinates, and emit one single-batch tensor per neighborhood.
        batch, count, _, _ = X_nbs.shape
        stacked = X_nbs.reshape([batch * count, -1, 4, 3])
        return [frame[None, ...] for frame in stacked.unbind(0)]

    # Axis 2 is divided by 4 to match the 4-atom grouping used in _as_frames.
    C_nb = torch.ones([1, nb_target.shape[2] // 4])

    target_frames = _as_frames(nb_target)
    aligned_frames = _as_frames(nb_aligned)
    mobile_frames = _as_frames(nb_mobile)

    # Every trajectory starts with the same target frame (the original note
    # here reads "Fight pymol confusion").
    anchor = target_frames[10]
    outputs = (
        ("X_nb_target.cif", target_frames),
        ("X_nb_noise_aligned.cif", aligned_frames),
        ("X_nb_noise.cif", mobile_frames),
    )
    for path, frames in outputs:
        trajectory = [anchor] + frames
        Protein.from_XCS_trajectory(trajectory, C_nb, 0.0 * C_nb).to_CIF(path)