Spaces:
Sleeping
Sleeping
File size: 3,258 Bytes
ce7bf5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 |
from unittest import TestCase
import numpy as np
import pytest
import torch
from chroma import constants
from chroma.data import Protein
from chroma.layers.structure import backbone, sidechain
class TestSideChain(TestCase):
def setUp(self):
self.builder = sidechain.SideChainBuilder()
self.chi_angles = sidechain.ChiAngles()
self.rmsd_loss = sidechain.LossSideChainRMSD()
self.clash_loss = sidechain.LossSidechainClashes()
self.frame_loss = sidechain.LossFrameAlignedGraph(distance_eps=1e-9)
self.distance_loss = sidechain.LossAllAtomDistances()
self.frame_builder = sidechain.AllAtomFrameBuilder()
pdb_id = "1SHG"
self.X, self.C, self.S = Protein(pdb_id).to_XCS(all_atom=True)
def test_chi_cartesian_round_trip(self):
X, C, S = self.X, self.C, self.S
X_bb = X[:, :, :4, :]
chi, mask_chi = self.chi_angles(X, C, S)
X_reference, mask_X = self.builder(X_bb, C, S, chi)
# Test round trip processing
chi_direct, _ = self.chi_angles(X_reference, C, S)
X_cycle, _ = self.builder(X_bb, C, S, chi_direct)
chi_cycle, _ = self.chi_angles(X_cycle, C, S)
_embed = lambda a: torch.stack([torch.cos(a), torch.sin(a)], -1)
self.assertTrue(torch.allclose(X_reference, X_cycle, atol=1e-1))
self.assertTrue(torch.allclose(_embed(chi), _embed(chi_cycle), atol=1e-2))
loss = self.rmsd_loss(X, X_cycle, C, S)
loss = self.clash_loss(X_cycle, C, S)
def test_integration(self):
num_letters = 20
chi = np.pi * torch.rand([1, num_letters, 4])
X_bb = backbone.ProteinBackbone(num_letters, init_state="beta")()
S = torch.arange(num_letters).unsqueeze(0)
C = torch.ones_like(S)
X, mask_X = self.builder(X_bb, C, S, chi)
chi, mask_chi = self.chi_angles(X, C, S)
self.assertTrue(
np.allclose(
mask_X.sum([-1, -2]).data.numpy(), np.asarray(constants.AA20_NUM_ATOMS)
)
)
self.assertTrue(
np.allclose(
mask_chi.sum(-1).data.numpy(), np.asarray(constants.AA20_NUM_CHI)
)
)
def test_frame_builder_round_trip(self):
X, C, S = self.X, self.C, self.S
x, q, chi = self.frame_builder.inverse(X, C, S)
X_cycle, mask_atoms = self.frame_builder(x, q, chi, C, S)
x = x + torch.randn_like(x) * 10.0
q = q + torch.randn_like(q) * 2.0
X_perturb, mask_atoms = self.frame_builder(x, q, chi, C, S)
mask = (C > 0).float()
_loss = lambda loss: (loss * mask).sum() / mask.sum()
loss_cycle_avg = _loss(self.frame_loss(X, X_cycle, C, S))
loss_perturb_avg = _loss(self.frame_loss(X, X_perturb, C, S))
print(loss_cycle_avg, loss_perturb_avg)
self.assertTrue(loss_cycle_avg.item() < 1.0)
self.assertTrue(loss_perturb_avg.item() > 1.0)
loss_cycle_avg = _loss(self.distance_loss(X, X_cycle, C, S))
loss_perturb_avg = _loss(self.distance_loss(X, X_perturb, C, S))
print(loss_cycle_avg, loss_perturb_avg)
self.assertTrue(loss_cycle_avg.item() < 1.0)
self.assertTrue(loss_perturb_avg.item() > 1.0)
|