Spaces:
Sleeping
Sleeping
from unittest import TestCase | |
import numpy as np | |
import pytest | |
import torch | |
from chroma import constants | |
from chroma.data import Protein | |
from chroma.layers.structure import backbone, sidechain | |
class TestSideChain(TestCase):
    """Round-trip and consistency tests for ``chroma.layers.structure.sidechain``.

    Every test drives the layers constructed in ``setUp``.  Two of the tests
    use the all-atom ``(X, C, S)`` tensors loaded from PDB entry ``1SHG``; the
    third builds a synthetic backbone instead.
    """

    def setUp(self):
        """Seed the RNG, build the layers under test and load the reference protein."""
        # test_integration and test_frame_builder_round_trip draw random
        # tensors; a fixed seed keeps any failure reproducible.
        torch.manual_seed(0)

        self.builder = sidechain.SideChainBuilder()
        self.chi_angles = sidechain.ChiAngles()
        self.rmsd_loss = sidechain.LossSideChainRMSD()
        self.clash_loss = sidechain.LossSidechainClashes()
        self.frame_loss = sidechain.LossFrameAlignedGraph(distance_eps=1e-9)
        self.distance_loss = sidechain.LossAllAtomDistances()
        self.frame_builder = sidechain.AllAtomFrameBuilder()

        pdb_id = "1SHG"
        self.X, self.C, self.S = Protein(pdb_id).to_XCS(all_atom=True)

    @staticmethod
    def _embed_angles(angles):
        """Map angles to ``(cos, sin)`` pairs stacked on a new trailing axis.

        Comparing on the unit circle makes angles that differ by a multiple
        of 2*pi compare equal.
        """
        return torch.stack([torch.cos(angles), torch.sin(angles)], -1)

    @staticmethod
    def _masked_mean(values, mask):
        """Return ``(values * mask).sum() / mask.sum()``."""
        return (values * mask).sum() / mask.sum()

    def test_chi_cartesian_round_trip(self):
        """chi -> coordinates -> chi -> coordinates must reproduce itself."""
        X, C, S = self.X, self.C, self.S
        # The builder receives only the first four atom slots of each residue
        # (presumably the backbone atoms -- confirm against SideChainBuilder).
        X_bb = X[:, :, :4, :]

        chi, _ = self.chi_angles(X, C, S)
        X_reference, _ = self.builder(X_bb, C, S, chi)

        # Test round trip processing
        chi_direct, _ = self.chi_angles(X_reference, C, S)
        X_cycle, _ = self.builder(X_bb, C, S, chi_direct)
        chi_cycle, _ = self.chi_angles(X_cycle, C, S)

        self.assertTrue(
            torch.allclose(X_reference, X_cycle, atol=1e-1),
            msg="rebuilt coordinates differ from the reference build",
        )
        self.assertTrue(
            torch.allclose(
                self._embed_angles(chi), self._embed_angles(chi_cycle), atol=1e-2
            ),
            msg="chi angles changed after the coordinate round trip",
        )

        # Smoke checks only: both losses must evaluate without raising on the
        # rebuilt structure.  Their values are intentionally not asserted.
        self.rmsd_loss(X, X_cycle, C, S)
        self.clash_loss(X_cycle, C, S)

    def test_integration(self):
        """Built atom and chi masks must match the per-residue-type constants."""
        num_letters = 20
        chi = np.pi * torch.rand([1, num_letters, 4])
        X_bb = backbone.ProteinBackbone(num_letters, init_state="beta")()
        # One residue per token index 0..num_letters-1, all with chain value 1.
        S = torch.arange(num_letters).unsqueeze(0)
        C = torch.ones_like(S)

        X, mask_X = self.builder(X_bb, C, S, chi)
        _, mask_chi = self.chi_angles(X, C, S)

        atom_counts = mask_X.sum([-1, -2]).detach().numpy()
        chi_counts = mask_chi.sum(-1).detach().numpy()

        self.assertTrue(
            np.allclose(atom_counts, np.asarray(constants.AA20_NUM_ATOMS)),
            msg="atom mask totals disagree with constants.AA20_NUM_ATOMS",
        )
        self.assertTrue(
            np.allclose(chi_counts, np.asarray(constants.AA20_NUM_CHI)),
            msg="chi mask totals disagree with constants.AA20_NUM_CHI",
        )

    def test_frame_builder_round_trip(self):
        """Frame losses are small for a clean round trip and large after noise."""
        X, C, S = self.X, self.C, self.S

        x, q, chi = self.frame_builder.inverse(X, C, S)
        X_cycle, _ = self.frame_builder(x, q, chi, C, S)

        # Add large Gaussian noise to both frame inputs so the rebuilt
        # structure is expected to sit far from the original.
        x_noisy = x + torch.randn_like(x) * 10.0
        q_noisy = q + torch.randn_like(q) * 2.0
        X_perturb, _ = self.frame_builder(x_noisy, q_noisy, chi, C, S)

        # Average each loss over positions where C > 0 only.
        mask = (C > 0).float()

        for label, loss_fn in (
            ("frame", self.frame_loss),
            ("distance", self.distance_loss),
        ):
            loss_cycle_avg = self._masked_mean(loss_fn(X, X_cycle, C, S), mask)
            loss_perturb_avg = self._masked_mean(loss_fn(X, X_perturb, C, S), mask)
            self.assertLess(
                loss_cycle_avg.item(),
                1.0,
                msg=f"{label} loss too high for the clean round trip",
            )
            self.assertGreater(
                loss_perturb_avg.item(),
                1.0,
                msg=f"{label} loss too low for the perturbed structure",
            )