# Header copied from the web file viewer this source was taken from:
# Spaces status: Sleeping | File size: 5,612 Bytes | e9e75df (presumably the commit id)
from pathlib import Path
from unittest import TestCase
import numpy as np
import torch
import torch.nn.functional as F
import chroma
from chroma.data import Protein
from chroma.layers.structure import geometry
class TestDistances(TestCase):
    """Regression test for ``geometry.Distances`` against stored reference values."""

    def test_sample(self):
        """Run the layer on a seeded random input and compare with the reference.

        The input has shape (1, 2, 4, 3) and the layer is called with
        ``dim=-2``; the output is expected to have shape (1, 2, 4, 4).
        """
        distances = geometry.Distances()
        # Fixed seed: the hard-coded reference below is only valid for this input.
        torch.manual_seed(7)
        input_x = torch.rand(1, 2, 4, 3)
        dim = -2
        # NOTE(review): the second argument is passed as None; presumably an
        # optional mask -- confirm against geometry.Distances.
        predicted = distances(input_x, None, dim)

        # assertEqual reports the offending shape on failure, whereas
        # assertTrue(a == b) would only report "False is not true".
        self.assertEqual(tuple(predicted.shape), (1, 2, 4, 4))

        expected = torch.tensor(
            [
                [
                    [
                        [0.0316, 0.2681, 0.6169, 0.7371],
                        [0.2681, 0.0316, 0.6037, 0.6646],
                        [0.6169, 0.6037, 0.0316, 0.7079],
                        [0.7371, 0.6646, 0.7079, 0.0316],
                    ],
                    [
                        [0.0316, 0.6395, 0.8179, 0.6187],
                        [0.6395, 0.0316, 1.1853, 0.6260],
                        [0.8179, 1.1853, 0.0316, 0.8764],
                        [0.6187, 0.6260, 0.8764, 0.0316],
                    ],
                ]
            ]
        )
        self.assertTrue(
            torch.allclose(predicted, expected, rtol=1e-3),
            msg=f"Distances output differs from the reference values:\n{predicted}",
        )
class TestRotations(TestCase):
    """Tests the rotation-matrix <-> quaternion conversions in ``geometry``."""

    def setUp(self):
        """Store three reference rotation matrices (``self.R``) and the
        quaternions they are compared against (``self.q``), row for row."""
        rotation_first = [
            [0.9027011, -0.1829866, -0.3894183],
            [-0.3146039, 0.3367128, -0.8874959],
            [0.2935220, 0.9236560, 0.2463827],
        ]
        rotation_identity = [
            [1.0, 0.0, 0.0],
            [0.0, 1.0, 0.0],
            [0.0, 0.0, 1.0],
        ]
        rotation_last = [
            [-0.6638935, 0.6988353, 0.2662229],
            [-0.6322795, -0.3344426, -0.6988353],
            [-0.3993345, -0.6322795, 0.6638935],
        ]
        self.R = torch.tensor([rotation_first, rotation_identity, rotation_last])

        quaternion_rows = [
            [0.7883205, 0.5743704, -0.2165808, -0.0417398],
            [1.0, 0.0, 0.0, 0.0],
            [0.4079085, 0.0407909, 0.4079085, -0.815817],
        ]
        self.q = torch.tensor(quaternion_rows)

    def test_rotations_from_quaternions(self):
        """Quaternions must convert to the reference matrices (atol 1e-3)."""
        converted = geometry.rotations_from_quaternions(self.q)
        matrices_match = torch.allclose(self.R, converted, atol=1e-3)
        self.assertTrue(matrices_match)

    def test_quaternions_from_rotations(self):
        """Matrices must convert to the reference quaternions (atol 1e-3)."""
        converted = geometry.quaternions_from_rotations(self.R, eps=0.0)
        quaternions_match = torch.allclose(self.q, converted, atol=1e-3)
        self.assertTrue(quaternions_match)

    def test_round_trip(self):
        """q -> R -> q -> R must reproduce both reference tensors (atol 1e-3)."""
        matrices = geometry.rotations_from_quaternions(self.q)
        quaternions_back = geometry.quaternions_from_rotations(matrices, eps=0.0)
        matrices_back = geometry.rotations_from_quaternions(quaternions_back)

        quaternions_match = torch.allclose(self.q, quaternions_back, atol=1e-3)
        self.assertTrue(quaternions_match)
        matrices_match = torch.allclose(self.R, matrices_back, atol=1e-3)
        self.assertTrue(matrices_match)
class TestExtendAtoms(TestCase):
    """Round-trip test between ``geometry.extend_atoms`` and the measurement
    helpers ``geometry.lengths`` / ``angles`` / ``dihedrals``."""

    def test_extend_atoms_round_trip(self):
        """Build a fourth atom from (length, angle, dihedral) and measure them back."""
        # Test cycle-consistency of geometry measurement and building routines.
        # Seed the global RNG (as TestDistances does) so that any failure can
        # be reproduced exactly instead of depending on a fresh random draw.
        torch.manual_seed(0)

        num_batch, num_residues = 10, 30
        # Three random anchor atoms per residue, each (num_batch, num_residues, 3).
        X1, X2, X3 = torch.randn([num_batch, num_residues, 3, 3]).unbind(-1)
        # Lengths are > 1, angles lie in (0, pi), dihedrals are unbounded.
        L = torch.exp(torch.randn([num_batch, num_residues])) + 1.0
        A = np.pi * torch.sigmoid(torch.randn([num_batch, num_residues]))
        D = np.pi * torch.randn([num_batch, num_residues])

        X4 = geometry.extend_atoms(X1, X2, X3, L, A, D, distance_eps=1e-6)

        L_recover = geometry.lengths(X3, X4, distance_eps=0.0)
        A_recover = geometry.angles(X2, X3, X4, distance_eps=0.0)
        D_recover = geometry.dihedrals(X1, X2, X3, X4, distance_eps=0.0)

        def _embed(a):
            # Compare periodic angles through (cos, sin) so that values that
            # differ by a multiple of 2*pi are treated as equal.
            return torch.stack([torch.cos(a), torch.sin(a)], -1)

        self.assertTrue(
            torch.allclose(L, L_recover, atol=1e-2),
            msg="recovered lengths differ from the inputs",
        )
        self.assertTrue(
            torch.allclose(A, A_recover, atol=1e-2),
            msg="recovered angles differ from the inputs",
        )
        self.assertTrue(
            torch.allclose(_embed(D), _embed(D_recover), atol=1e-2),
            msg="recovered dihedrals differ from the inputs",
        )
class TestVirtualAtomsCA(TestCase):
    """Checks the geometry of the atoms placed by ``geometry.VirtualAtomsCA``."""

    def test_atom_placement(self):
        """Placed atoms must reproduce the placer's own target geometry.

        For each virtual atom type, the bond length to CA, the N-CA-virtual
        angle and the C-N-CA-virtual dihedral are measured and compared, with
        an absolute tolerance of 1e-2, against the values returned by
        ``atom_placer.geometry()``.
        """
        # Load test case
        file_cif = str(
            Path(chroma.__file__).parent.parent / "tests" / "resources" / "5jg9.cif"
        )
        X, C, _ = Protein(file_cif).to_XCS()

        # Loop-invariant inputs: positions with C <= 0 are excluded from the
        # comparison, and only the first three of the four atoms are needed.
        mask = (C > 0).type(torch.float32)
        X_N, X_CA, X_C, _ = X.unbind(2)

        for v_type in ["cbeta", "dicons"]:
            # subTest labels a failure with the variant and still runs the other.
            with self.subTest(virtual_type=v_type):
                # Place atoms
                atom_placer = geometry.VirtualAtomsCA(virtual_type=v_type)
                X_virtual = atom_placer(X, C)

                # DEBUG: Sanity check is useful for testing
                # geometry.debug_pymol_virtual_atoms(X, X_virtual, 'test_5jg9.pml')

                # Test that generated angles are correct
                bonds = torch.norm(X_virtual - X_CA, dim=-1)
                angles = geometry.angles(
                    X_N, X_CA, X_virtual, distance_eps=1e-6, degrees=True
                )
                dihedrals = geometry.dihedrals(
                    X_C, X_N, X_CA, X_virtual, distance_eps=1e-6, degrees=True
                )
                bond_t, angle_t, dihedral_t = atom_placer.geometry()

                bond_error = mask * (bonds - bond_t)
                angle_error = mask * (angles - angle_t)
                # Dihedrals are periodic: fold the difference into [-180, 180)
                # so that values such as -179.99 and 180.0 count as equal.
                dihedral_delta = (dihedrals - dihedral_t + 180.0) % 360.0 - 180.0
                dihedral_error = mask * dihedral_delta

                self.assertTrue(
                    torch.allclose(
                        bond_error, torch.zeros_like(bond_error), atol=1e-2
                    ),
                    msg="bond lengths deviate from the target",
                )
                self.assertTrue(
                    torch.allclose(
                        angle_error, torch.zeros_like(angle_error), atol=1e-2
                    ),
                    msg="angles deviate from the target",
                )
                self.assertTrue(
                    torch.allclose(
                        dihedral_error, torch.zeros_like(dihedral_error), atol=1e-2
                    ),
                    msg="dihedrals deviate from the target",
                )
|