Spaces:
Sleeping
Sleeping
from unittest import TestCase | |
import pytest | |
import torch | |
from chroma.layers.structure.backbone import ( | |
BackboneBuilder, | |
LossBackboneResidueDistance, | |
ProteinBackbone, | |
RigidTransform, | |
RigidTransformer, | |
) | |
class TestProteinBackbone(TestCase):
    """Tests for constructing and sampling ``ProteinBackbone`` modules."""

    # Coordinates asserted for a default single-residue backbone. The
    # internal-coordinate and Cartesian tests both compare against these values.
    _EXPECTED_SINGLE_RESIDUE = torch.tensor(
        [
            [
                [
                    [0.1331, -1.6303, -0.7377],
                    [0.0414, -0.1759, -0.8080],
                    [-0.3710, 0.4114, 0.5376],
                    [0.1965, 1.3947, 1.0081],
                ]
            ]
        ]
    )

    def test_cuda(self):
        """Check that a backbone module can be moved to a CUDA device."""
        if not torch.cuda.is_available():
            # Report the case as skipped rather than passing with no assertions.
            self.skipTest("CUDA is not available")
        # Any failure is left to propagate so the report shows the real
        # traceback instead of a bare "None is not None" assertion.
        protein_backbone = ProteinBackbone(1).cuda()
        self.assertIsNotNone(protein_backbone)

    def test_sample(self):
        """A default one-residue backbone yields the reference coordinates."""
        protein_backbone = ProteinBackbone(1)
        predicted = protein_backbone()
        self.assertEqual((1, 1, 4, 3), predicted.shape)
        self.assertTrue(
            torch.allclose(self._EXPECTED_SINGLE_RESIDUE, predicted, rtol=1e-03)
        )

    def test_random_init_backbone(self):
        """An empty ``init_state`` still produces output of the expected shape."""
        protein_backbone = ProteinBackbone(1, init_state="")
        predicted = protein_backbone()
        self.assertEqual((1, 1, 4, 3), predicted.shape)

    def test_sample_cartesian(self):
        """With ``use_internal_coords=False`` the same coordinates are produced."""
        protein_backbone = ProteinBackbone(1, use_internal_coords=False)
        predicted = protein_backbone()
        self.assertEqual((1, 1, 4, 3), predicted.shape)
        self.assertTrue(
            torch.allclose(self._EXPECTED_SINGLE_RESIDUE, predicted, rtol=1e-03)
        )

    def test_initialized_sample(self):
        """A backbone built from ``X_init`` yields the recorded coordinates."""
        torch.manual_seed(7)
        input_x = torch.rand(1, 2, 4, 3)
        predicted = ProteinBackbone(1, use_internal_coords=False, X_init=input_x)()
        expected = torch.tensor(
            [
                [
                    [
                        [-5.3644e-07, -2.6469e-01, 1.4716e-01],
                        [1.2197e-01, -2.3073e-01, -8.6988e-02],
                        [-3.2784e-01, 1.6625e-01, -1.4673e-01],
                        [3.1635e-01, 3.9145e-01, 3.8886e-02],
                    ],
                    [
                        [-2.4808e-01, -2.5717e-01, -6.6959e-02],
                        [-1.7564e-01, 2.5689e-01, -4.3900e-01],
                        [4.3500e-01, -3.5570e-01, 3.7083e-01],
                        [-1.2175e-01, 2.9370e-01, 1.8280e-01],
                    ],
                ]
            ]
        )
        # One expected entry is ~5e-07 in magnitude, so a purely relative
        # tolerance would demand an almost exact match for it. A small absolute
        # tolerance keeps the check meaningful without depending on
        # float32 rounding noise.
        self.assertTrue(torch.allclose(predicted, expected, rtol=1e-03, atol=1e-06))
class TestRigidTransform(TestCase):
    """Tests for the ``RigidTransform`` layer."""

    def test_sample(self):
        """A default-constructed transform is expected to act as the identity."""
        layer = RigidTransform()

        # Seed only after the layer exists, so the sampled coordinates do not
        # depend on anything the constructor may draw from the RNG.
        torch.manual_seed(7)
        coords = torch.rand(1, 1, 4, 3)

        transformed = layer(coords)
        unchanged = torch.allclose(transformed, coords, rtol=1e-3)
        self.assertTrue(unchanged)
class TestRigidTransformer(TestCase):
    """Tests for the ``RigidTransformer`` layer."""

    def test_sample(self):
        """Check the identity and pure-translation cases on centered input.

        The transformer is built with ``center_rotation=True`` and
        ``keep_centered=True``; the assertions compare its output against the
        mean-centered input, optionally shifted by the requested translation.
        """
        rigid_transformer = RigidTransformer(center_rotation=True, keep_centered=True)

        # Fixed seed so that a failure can be reproduced with the same input.
        torch.manual_seed(7)
        input_x = torch.rand(1, 1, 4, 3)
        # Mean over all atoms has shape (1, 3) and broadcasts against the
        # (1, 1, 4, 3) input.
        mean_centered = input_x - input_x.reshape(1, -1, 3).mean(dim=-2)

        # Identity: zero translation with the quaternion (1, 0, 0, 0).
        no_translation = torch.zeros(1, 3)
        identity_q = torch.tensor([[1.0, 0.0, 0.0, 0.0]])
        predicted = rigid_transformer(input_x, no_translation, identity_q)
        self.assertTrue(torch.allclose(predicted, mean_centered, rtol=1e-3))

        # Translation: unit shift along x, same quaternion.
        x_translation = torch.tensor([[1.0, 0.0, 0.0]])
        expected = mean_centered + x_translation
        predicted = rigid_transformer(input_x, x_translation, identity_q)
        self.assertTrue(torch.allclose(predicted, expected, rtol=1e-3))
class TestBackboneBuilder(TestCase):
    """Tests for building backbone coordinates with ``BackboneBuilder``."""

    # Single-residue dihedral inputs shared by both tests.
    _PHI = torch.tensor([[-1.0472]])
    _PSI = torch.tensor([[-0.7854]])

    # Coordinates asserted for the inputs above.
    _EXPECTED = torch.tensor(
        [
            [
                [
                    [-1.2286, 0.2223, -1.2286],
                    [-1.3203, 1.6767, -1.2989],
                    [-1.7327, 2.2640, 0.0468],
                    [-1.1652, 3.2473, 0.5172],
                ]
            ]
        ]
    )

    def test_sample(self):
        """Building from phi/psi alone yields the reference coordinates."""
        backbone_builder = BackboneBuilder()
        predicted = backbone_builder(self._PHI, self._PSI)
        self.assertTrue(torch.allclose(self._EXPECTED, predicted, rtol=1e-3))

    def test_custom_sample(self):
        """Passing the builder's own geometry explicitly gives the same result.

        Omega, bond angles and bond lengths are read from the builder's
        ``angles`` and ``lengths`` mappings and passed as explicit arguments;
        the output must match the coordinates of the phi/psi-only call.
        """
        num_residues = 1
        backbone_builder = BackboneBuilder()

        lengths = torch.tensor(
            [[backbone_builder.lengths[key] for key in ("C_N", "N_CA", "CA_C")]],
            dtype=torch.float32,
        ).repeat(1, num_residues)  # (1, 3) for a single residue
        angles = torch.tensor(
            [[backbone_builder.angles[key] for key in ("CA_C_N", "C_N_CA", "N_CA_C")]],
            dtype=torch.float32,
        ).repeat(1, num_residues)  # (1, 3) for a single residue
        omega = backbone_builder.angles["omega"] * torch.ones(1, num_residues)  # (1, 1)

        predicted = backbone_builder(self._PHI, self._PSI, omega, angles, lengths)
        self.assertTrue(torch.allclose(self._EXPECTED, predicted, rtol=1e-3))