Spaces:
Sleeping
Sleeping
from pathlib import Path | |
from unittest import TestCase | |
import numpy as np | |
import torch | |
import torch.nn.functional as F | |
import chroma | |
from chroma.data import Protein | |
from chroma.layers.structure import geometry | |
class TestDistances(TestCase):
    """Regression test for ``geometry.Distances`` against stored reference values."""

    def test_sample(self):
        """Check output shape and values of the layer on a fixed-seed random input.

        The input is ``torch.rand(1, 2, 4, 3)`` drawn after seeding with 7, and
        the layer is called as ``distances(input_x, None, dim)`` with
        ``dim=-2``. The result must have shape ``(1, 2, 4, 4)`` and match the
        snapshot below within ``rtol=1e-3``.
        """
        distances = geometry.Distances()
        # Fix the RNG so the random input always matches the stored snapshot.
        torch.manual_seed(7)
        input_x = torch.rand(1, 2, 4, 3)
        dim = -2
        predicted = distances(input_x, None, dim)

        # assertEqual reports both shapes on failure, unlike assertTrue(a == b).
        self.assertEqual(tuple(predicted.shape), (1, 2, 4, 4))

        # NOTE(review): the 0.0316 diagonal is ~sqrt(1e-3), presumably an
        # epsilon inside the layer's square root -- confirm against
        # geometry.Distances.
        expected = torch.tensor(
            [
                [
                    [
                        [0.0316, 0.2681, 0.6169, 0.7371],
                        [0.2681, 0.0316, 0.6037, 0.6646],
                        [0.6169, 0.6037, 0.0316, 0.7079],
                        [0.7371, 0.6646, 0.7079, 0.0316],
                    ],
                    [
                        [0.0316, 0.6395, 0.8179, 0.6187],
                        [0.6395, 0.0316, 1.1853, 0.6260],
                        [0.8179, 1.1853, 0.0316, 0.8764],
                        [0.6187, 0.6260, 0.8764, 0.0316],
                    ],
                ]
            ]
        )
        self.assertTrue(
            torch.allclose(predicted, expected, rtol=1e-3),
            msg="Distances output deviates from the stored reference values",
        )
class TestRotations(TestCase):
    """Conversion tests between rotation matrices and quaternions.

    ``setUp`` stores three reference rotation matrices in ``self.R`` and the
    three matching quaternions in ``self.q``; every test compares converted
    values against these with an absolute tolerance of 1e-3.
    """

    def setUp(self):
        """Create the paired reference tensors ``self.R`` and ``self.q``."""
        rotation_a = [
            [0.9027011, -0.1829866, -0.3894183],
            [-0.3146039, 0.3367128, -0.8874959],
            [0.2935220, 0.9236560, 0.2463827],
        ]
        rotation_identity = [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]
        rotation_b = [
            [-0.6638935, 0.6988353, 0.2662229],
            [-0.6322795, -0.3344426, -0.6988353],
            [-0.3993345, -0.6322795, 0.6638935],
        ]
        self.R = torch.tensor([rotation_a, rotation_identity, rotation_b])

        # One quaternion per matrix above, in the same order.
        self.q = torch.tensor(
            [
                [0.7883205, 0.5743704, -0.2165808, -0.0417398],
                [1.0, 0.0, 0.0, 0.0],
                [0.4079085, 0.0407909, 0.4079085, -0.815817],
            ]
        )

    def _assert_matches(self, reference, candidate):
        """Assert ``candidate`` equals ``reference`` within ``atol=1e-3``."""
        self.assertTrue(torch.allclose(reference, candidate, atol=1e-3))

    def test_rotations_from_quaternions(self):
        """Quaternions convert to the reference rotation matrices."""
        converted = geometry.rotations_from_quaternions(self.q)
        self._assert_matches(self.R, converted)

    def test_quaternions_from_rotations(self):
        """Rotation matrices convert to the reference quaternions (``eps=0.0``)."""
        converted = geometry.quaternions_from_rotations(self.R, eps=0.0)
        self._assert_matches(self.q, converted)

    def test_round_trip(self):
        """q -> R -> q -> R reproduces both reference tensors."""
        matrices = geometry.rotations_from_quaternions(self.q)
        quats_back = geometry.quaternions_from_rotations(matrices, eps=0.0)
        matrices_back = geometry.rotations_from_quaternions(quats_back)
        self._assert_matches(self.q, quats_back)
        self._assert_matches(self.R, matrices_back)
class TestExtendAtoms(TestCase):
    """Round-trip test between ``geometry.extend_atoms`` and the measurement helpers."""

    def test_extend_atoms_round_trip(self):
        """Build a fourth atom from (length, angle, dihedral) and measure them back.

        Random ``X1, X2, X3`` and random internal coordinates ``L, A, D`` are
        passed to ``geometry.extend_atoms`` to obtain ``X4``. The values are
        then recovered with ``geometry.lengths(X3, X4)``,
        ``geometry.angles(X2, X3, X4)`` and
        ``geometry.dihedrals(X1, X2, X3, X4)`` and must match the inputs
        within ``atol=1e-2``.
        """
        # Fix the RNG so that any failure of this randomized test is reproducible.
        torch.manual_seed(0)

        num_batch, num_residues = 10, 30
        X1, X2, X3 = torch.randn([num_batch, num_residues, 3, 3]).unbind(-1)

        # By construction L > 1, A lies in (0, pi), and D is unbounded.
        L = torch.exp(torch.randn([num_batch, num_residues])) + 1.0
        A = np.pi * torch.sigmoid(torch.randn([num_batch, num_residues]))
        D = np.pi * torch.randn([num_batch, num_residues])

        X4 = geometry.extend_atoms(X1, X2, X3, L, A, D, distance_eps=1e-6)

        L_recover = geometry.lengths(X3, X4, distance_eps=0.0)
        A_recover = geometry.angles(X2, X3, X4, distance_eps=0.0)
        D_recover = geometry.dihedrals(X1, X2, X3, X4, distance_eps=0.0)

        def _embed(a):
            # Map an angle to (cos, sin) so the comparison ignores multiples
            # of 2*pi; D is drawn without range restriction.
            return torch.stack([torch.cos(a), torch.sin(a)], -1)

        self.assertTrue(
            torch.allclose(L, L_recover, atol=1e-2),
            msg="bond lengths not recovered",
        )
        self.assertTrue(
            torch.allclose(A, A_recover, atol=1e-2),
            msg="bond angles not recovered",
        )
        self.assertTrue(
            torch.allclose(_embed(D), _embed(D_recover), atol=1e-2),
            msg="dihedral angles not recovered",
        )
class TestVirtualAtomsCA(TestCase):
    """Geometry test for ``geometry.VirtualAtomsCA`` on the 5jg9 test structure."""

    def test_atom_placement(self):
        """Placed virtual atoms must reproduce the placer's target geometry.

        For each virtual type ("cbeta", "dicons") the test places virtual
        atoms, then measures the CA-virtual bond length, the N-CA-virtual
        angle and the C-N-CA-virtual dihedral (both in degrees). Each is
        compared with the targets returned by ``atom_placer.geometry()``.
        Errors are multiplied by the ``C > 0`` mask and must be within 1e-2
        of zero.
        """
        # Load test case
        file_cif = str(
            Path(chroma.__file__).parent.parent / "tests" / "resources" / "5jg9.cif"
        )
        X, C, _ = Protein(file_cif).to_XCS()

        for v_type in ["cbeta", "dicons"]:
            # subTest names the failing virtual type and lets the remaining
            # type still run after a failure.
            with self.subTest(virtual_type=v_type):
                # Place atoms
                atom_placer = geometry.VirtualAtomsCA(virtual_type=v_type)
                X_virtual = atom_placer(X, C)

                # DEBUG: Sanity check is useful for testing
                # geometry.debug_pymol_virtual_atoms(X, X_virtual, 'test_5jg9.pml')

                # Test that generated angles are correct
                X_N, X_CA, X_C, _ = X.unbind(2)
                bonds = torch.norm(X_virtual - X_CA, dim=-1)
                angles = geometry.angles(
                    X_N, X_CA, X_virtual, distance_eps=1e-6, degrees=True
                )
                dihedrals = geometry.dihedrals(
                    X_C, X_N, X_CA, X_virtual, distance_eps=1e-6, degrees=True
                )
                bond_t, angle_t, dihedral_t = atom_placer.geometry()

                # Zero out errors wherever C <= 0.
                mask = (C > 0).type(torch.float32)
                errors = (
                    ("bond", mask * (bonds - bond_t)),
                    ("angle", mask * (angles - angle_t)),
                    ("dihedral", mask * (dihedrals - dihedral_t)),
                )
                for label, error in errors:
                    self.assertTrue(
                        torch.allclose(error, torch.zeros_like(error), atol=1e-2),
                        msg=f"{label} error exceeds tolerance for {v_type!r}",
                    )