# Hukuna's picture
# Upload 221 files
# ce7bf5b verified
# raw
# history blame
# 5.61 kB
from pathlib import Path
from unittest import TestCase
import numpy as np
import torch
import torch.nn.functional as F
import chroma
from chroma.data import Protein
from chroma.layers.structure import geometry
class TestDistances(TestCase):
    """Regression test pinning the output of ``geometry.Distances``."""

    def test_sample(self):
        """Compare the layer output for a seeded random input with stored values."""
        # The layer is built before seeding so that the draw below always
        # starts from the same RNG state as the stored reference assumes.
        layer = geometry.Distances()
        torch.manual_seed(7)
        coords = torch.rand(1, 2, 4, 3)
        axis = -2

        result = layer(coords, None, axis)
        self.assertTrue(result.shape == (1, 2, 4, 4))

        # Reference values for the seeded input above, one 4x4 table per
        # entry along the second dimension.
        first_table = [
            [0.0316, 0.2681, 0.6169, 0.7371],
            [0.2681, 0.0316, 0.6037, 0.6646],
            [0.6169, 0.6037, 0.0316, 0.7079],
            [0.7371, 0.6646, 0.7079, 0.0316],
        ]
        second_table = [
            [0.0316, 0.6395, 0.8179, 0.6187],
            [0.6395, 0.0316, 1.1853, 0.6260],
            [0.8179, 1.1853, 0.0316, 0.8764],
            [0.6187, 0.6260, 0.8764, 0.0316],
        ]
        reference = torch.tensor([[first_table, second_table]])

        self.assertTrue(torch.allclose(result, reference, rtol=1e-3))
class TestRotations(TestCase):
    """Conversion tests between rotation matrices and quaternions."""

    def setUp(self):
        """Store three rotation matrices and the quaternion paired with each."""
        first_rotation = [
            [0.9027011, -0.1829866, -0.3894183],
            [-0.3146039, 0.3367128, -0.8874959],
            [0.2935220, 0.9236560, 0.2463827],
        ]
        identity = [[1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]]
        third_rotation = [
            [-0.6638935, 0.6988353, 0.2662229],
            [-0.6322795, -0.3344426, -0.6988353],
            [-0.3993345, -0.6322795, 0.6638935],
        ]
        self.R = torch.tensor([first_rotation, identity, third_rotation])

        # Row i is the quaternion expected for self.R[i]; the identity matrix
        # is paired with [1, 0, 0, 0].
        self.q = torch.tensor(
            [
                [0.7883205, 0.5743704, -0.2165808, -0.0417398],
                [1.0, 0.0, 0.0, 0.0],
                [0.4079085, 0.0407909, 0.4079085, -0.815817],
            ]
        )

    def test_rotations_from_quaternions(self):
        """Quaternions must convert to the stored rotation matrices."""
        rotations = geometry.rotations_from_quaternions(self.q)
        matches = torch.allclose(self.R, rotations, atol=1e-3)
        self.assertTrue(matches)

    def test_quaternions_from_rotations(self):
        """Rotation matrices must convert to the stored quaternions."""
        quaternions = geometry.quaternions_from_rotations(self.R, eps=0.0)
        matches = torch.allclose(self.q, quaternions, atol=1e-3)
        self.assertTrue(matches)

    def test_round_trip(self):
        """q -> R -> q -> R must reproduce both stored fixtures."""
        rotations = geometry.rotations_from_quaternions(self.q)
        quaternions_back = geometry.quaternions_from_rotations(rotations, eps=0.0)
        rotations_back = geometry.rotations_from_quaternions(quaternions_back)

        self.assertTrue(torch.allclose(self.q, quaternions_back, atol=1e-3))
        self.assertTrue(torch.allclose(self.R, rotations_back, atol=1e-3))
class TestExtendAtoms(TestCase):
    """Cycle-consistency test for ``geometry.extend_atoms``."""

    def test_extend_atoms_round_trip(self):
        """Build a fourth atom from (length, angle, dihedral) and measure it back.

        Random atoms X1, X2, X3 and random internal coordinates L, A, D are
        passed to ``geometry.extend_atoms``; the resulting X4 is then measured
        with ``geometry.lengths``, ``geometry.angles`` and
        ``geometry.dihedrals`` and the recovered values are compared with the
        inputs.
        """
        # Seed the global RNG so a failure can be reproduced exactly, in line
        # with the seeded input used by TestDistances in this file.
        torch.manual_seed(0)

        # Test cycle-consistency of geometry measurement and building routines
        num_batch, num_residues = 10, 30
        X1, X2, X3 = torch.randn([num_batch, num_residues, 3, 3]).unbind(-1)
        # Lengths are strictly greater than 1, angles lie in (0, pi), and
        # dihedrals are unbounded multiples of pi.
        L = torch.exp(torch.randn([num_batch, num_residues])) + 1.0
        A = np.pi * torch.sigmoid(torch.randn([num_batch, num_residues]))
        D = np.pi * torch.randn([num_batch, num_residues])

        X4 = geometry.extend_atoms(X1, X2, X3, L, A, D, distance_eps=1e-6)

        L_recover = geometry.lengths(X3, X4, distance_eps=0.0)
        A_recover = geometry.angles(X2, X3, X4, distance_eps=0.0)
        D_recover = geometry.dihedrals(X1, X2, X3, X4, distance_eps=0.0)

        def _embed(a):
            # Compare dihedrals as (cos, sin) pairs so that values differing
            # by a multiple of 2*pi are treated as equal.
            return torch.stack([torch.cos(a), torch.sin(a)], -1)

        self.assertTrue(torch.allclose(L, L_recover, atol=1e-2))
        self.assertTrue(torch.allclose(A, A_recover, atol=1e-2))
        self.assertTrue(torch.allclose(_embed(D), _embed(D_recover), atol=1e-2))
class TestVirtualAtomsCA(TestCase):
    """Checks the geometry of atoms placed by ``geometry.VirtualAtomsCA``."""

    def test_atom_placement(self):
        """Placed atoms must match the placer's own bond, angle and dihedral.

        For each virtual type the bond length to CA, the N-CA-virtual angle
        and the C-N-CA-virtual dihedral are measured and compared with the
        values returned by ``atom_placer.geometry()``. Each virtual type runs
        in its own sub-test so a failure names the offending type and does not
        hide the result for the other one.
        """
        # Load test case
        file_cif = str(
            Path(chroma.__file__).parent.parent / "tests" / "resources" / "5jg9.cif"
        )
        X, C, _ = Protein(file_cif).to_XCS()

        # These depend only on the loaded structure, so compute them once.
        X_N, X_CA, X_C, _ = X.unbind(2)
        # Errors are only counted where C > 0 (presumably positions with
        # resolved residues -- confirm against Protein.to_XCS).
        mask = (C > 0).type(torch.float32)

        for v_type in ["cbeta", "dicons"]:
            with self.subTest(virtual_type=v_type):
                # Place atoms
                atom_placer = geometry.VirtualAtomsCA(virtual_type=v_type)
                X_virtual = atom_placer(X, C)

                # DEBUG: Sanity check is useful for testing
                # geometry.debug_pymol_virtual_atoms(X, X_virtual, 'test_5jg9.pml')

                # Test that generated angles are correct
                bonds = torch.norm(X_virtual - X_CA, dim=-1)
                angles = geometry.angles(
                    X_N, X_CA, X_virtual, distance_eps=1e-6, degrees=True
                )
                dihedrals = geometry.dihedrals(
                    X_C, X_N, X_CA, X_virtual, distance_eps=1e-6, degrees=True
                )

                bond_t, angle_t, dihedral_t = atom_placer.geometry()
                bond_error = mask * (bonds - bond_t)
                angle_error = mask * (angles - angle_t)
                dihedral_error = mask * (dihedrals - dihedral_t)

                self.assertTrue(
                    torch.allclose(
                        bond_error, torch.zeros_like(bond_error), atol=1e-2
                    )
                )
                self.assertTrue(
                    torch.allclose(
                        angle_error, torch.zeros_like(angle_error), atol=1e-2
                    )
                )
                self.assertTrue(
                    torch.allclose(
                        dihedral_error, torch.zeros_like(dihedral_error), atol=1e-2
                    )
                )