Hukuna's picture
Upload 275 files
e9e75df verified
raw
history blame
3.26 kB
from unittest import TestCase
import numpy as np
import pytest
import torch
from chroma import constants
from chroma.data import Protein
from chroma.layers.structure import backbone, sidechain
class TestSideChain(TestCase):
    """Round-trip and integration tests for ``chroma.layers.structure.sidechain``.

    Every test gets a fresh instance of each side-chain module under test and
    the all-atom ``(X, C, S)`` tensors of PDB entry 1SHG.
    """

    def setUp(self):
        # Seed the global torch RNG so the random chi angles and the random
        # frame perturbations used below are reproducible; unseeded noise
        # makes the threshold assertions flaky.
        torch.manual_seed(0)
        self.builder = sidechain.SideChainBuilder()
        self.chi_angles = sidechain.ChiAngles()
        self.rmsd_loss = sidechain.LossSideChainRMSD()
        self.clash_loss = sidechain.LossSidechainClashes()
        self.frame_loss = sidechain.LossFrameAlignedGraph(distance_eps=1e-9)
        self.distance_loss = sidechain.LossAllAtomDistances()
        self.frame_builder = sidechain.AllAtomFrameBuilder()
        pdb_id = "1SHG"
        self.X, self.C, self.S = Protein(pdb_id).to_XCS(all_atom=True)

    def test_chi_cartesian_round_trip(self):
        """Chi -> coordinates -> chi should be self-consistent."""
        X, C, S = self.X, self.C, self.S
        # The first four atom slots are used as the backbone input to the
        # builder; assumes the all-atom layout stores backbone atoms first
        # (TODO confirm).
        X_bb = X[:, :, :4, :]
        chi, _ = self.chi_angles(X, C, S)
        X_reference, _ = self.builder(X_bb, C, S, chi)

        # Round trip: re-measure chi from the rebuilt structure, rebuild
        # from those angles, then measure once more.
        chi_direct, _ = self.chi_angles(X_reference, C, S)
        X_cycle, _ = self.builder(X_bb, C, S, chi_direct)
        chi_cycle, _ = self.chi_angles(X_cycle, C, S)

        # Compare angles via (cos, sin) so values that differ by a multiple
        # of 2*pi count as equal.
        def _embed(a):
            return torch.stack([torch.cos(a), torch.sin(a)], -1)

        self.assertTrue(
            torch.allclose(X_reference, X_cycle, atol=1e-1),
            "rebuilt coordinates drifted after a chi round trip",
        )
        self.assertTrue(
            torch.allclose(_embed(chi), _embed(chi_cycle), atol=1e-2),
            "chi angles drifted after a coordinate round trip",
        )

        # Smoke check: the losses must run on the round-tripped structure.
        self.rmsd_loss(X, X_cycle, C, S)
        self.clash_loss(X_cycle, C, S)

    def test_integration(self):
        """Building one residue of each of the 20 types gives expected masks."""
        num_letters = 20
        # Random chi angles in [0, pi) for a chain whose sequence is 0..19,
        # i.e. one residue per amino-acid index.
        chi = np.pi * torch.rand([1, num_letters, 4])
        X_bb = backbone.ProteinBackbone(num_letters, init_state="beta")()
        S = torch.arange(num_letters).unsqueeze(0)
        C = torch.ones_like(S)

        X, mask_X = self.builder(X_bb, C, S, chi)
        _, mask_chi = self.chi_angles(X, C, S)

        # Per-residue mask totals should match the per-amino-acid tables.
        self.assertTrue(
            np.allclose(
                mask_X.sum([-1, -2]).data.numpy(), np.asarray(constants.AA20_NUM_ATOMS)
            )
        )
        self.assertTrue(
            np.allclose(
                mask_chi.sum(-1).data.numpy(), np.asarray(constants.AA20_NUM_CHI)
            )
        )

    def test_frame_builder_round_trip(self):
        """Frame decomposition round-trips, and large perturbations are detected."""
        X, C, S = self.X, self.C, self.S
        x, q, chi = self.frame_builder.inverse(X, C, S)
        X_cycle, _ = self.frame_builder(x, q, chi, C, S)

        # Heavily perturb the frame parameters (x and q are presumably frame
        # translations and rotation parameters). The losses should clearly
        # separate this case from the clean round trip.
        x_perturb = x + torch.randn_like(x) * 10.0
        q_perturb = q + torch.randn_like(q) * 2.0
        X_perturb, _ = self.frame_builder(x_perturb, q_perturb, chi, C, S)

        # Average the losses over positions with C > 0 (presumably C <= 0
        # marks padded or missing residues).
        mask = (C > 0).float()

        def _masked_mean(loss):
            return (loss * mask).sum() / mask.sum()

        for name, loss_fn in (
            ("frame", self.frame_loss),
            ("distance", self.distance_loss),
        ):
            with self.subTest(loss=name):
                loss_cycle = _masked_mean(loss_fn(X, X_cycle, C, S)).item()
                loss_perturb = _masked_mean(loss_fn(X, X_perturb, C, S)).item()
                self.assertLess(
                    loss_cycle, 1.0, f"{name} loss too high after clean round trip"
                )
                self.assertGreater(
                    loss_perturb, 1.0, f"{name} loss failed to detect perturbation"
                )