from pathlib import Path
from unittest import SkipTest, TestCase

import numpy as np
import pytest
import torch

import chroma
from chroma.data import Protein
from chroma.layers.structure.backbone import impute_masked_X
from chroma.layers.structure.mvn import (
    BackboneMVNGlobular,
    BackboneMVNResidueGas,
    ConditionalBackboneMVNGlobular,
)
from chroma.layers.structure.rmsd import BackboneRMSD


# Covariance models under test; the parameter values are assumed from the
# branch conditions below.
@pytest.fixture(params=["globular", "residue_gas"])
def noise(request):
    covariance_model = request.param
    if covariance_model in ["brownian", "globular"]:
        return BackboneMVNGlobular(
            covariance_model=covariance_model, complex_scaling=True,
        )
    else:
        return BackboneMVNResidueGas(
            covariance_model=covariance_model, complex_scaling=True,
        )


# Backbone sources for the tests: a real structure from the test resources or a
# synthetic random backbone; the parameter names are assumed from the branch below.
@pytest.fixture(params=["real", "synthetic"])
def XCS(request):
    xcs_type = request.param
    if xcs_type == "real":
        repo = Path(chroma.__file__).parent.parent
        test_cif = str(Path(repo, "tests", "resources", "6wgl.cif"))
        X, C, S = Protein(test_cif).to_XCS()
    else:
        num_batch, num_residues = 5, 100
        X = 10 * torch.randn([num_batch, num_residues * 4, 3])
        C = torch.ones([num_batch, num_residues])
        S = C.clone()
    return X, C, S


def test_full_covariance_and_sqrt_covariance_computation():
    num_batch, num_residues = 1, 100
    X = 10 * torch.randn([num_batch, num_residues, 4, 3])
    C = torch.ones([num_batch, num_residues])
    S = C.clone()
    D = torch.randint(low=0, high=2, size=C.size())

    # Fill in missing pieces
    X = impute_masked_X(X, C)
    C = torch.abs(C)

    mvn = BackboneMVNGlobular(covariance_model="globular", complex_scaling=True)
    cmvn = ConditionalBackboneMVNGlobular(
        covariance_model="globular", complex_scaling=True, X=X, C=C, D=D
    )

    # Test R
    Z = torch.randn_like(X).reshape(X.shape[0], -1, 3)
    RZ_mvn_implicit = mvn._multiply_R(Z, C)
    RZ_mvn_dense = (cmvn.R @ Z).reshape(RZ_mvn_implicit.shape)
    assert torch.allclose(RZ_mvn_implicit, RZ_mvn_dense, atol=1e-2)

    # Test RRt
    RRt_Z_implicit = mvn.multiply_covariance(Z, C)
    RRt_Z_dense = cmvn.RRt @ Z
    assert torch.allclose(RRt_Z_implicit, RRt_Z_dense, atol=1e-2)


def test_invertibility_R(noise, XCS):
    """Test invertibility of the covariance square root."""
    X, C, S = XCS
    X = X.reshape([X.shape[0], -1, 3])

    Ri_X = noise._multiply_R_inverse(X, C)
    R_Ri_X = noise._multiply_R(Ri_X, C)
    Rti_X = noise._multiply_R_inverse_transpose(X, C)
    Rt_Rti_X = noise._multiply_R_transpose(Rti_X, C)

    X = X.reshape(X.shape[0], C.shape[1], -1, 3)
    Ri_X = Ri_X.reshape(X.shape)
    R_Ri_X = R_Ri_X.reshape(X.shape)
    Rti_X = Rti_X.reshape(X.shape)
    Rt_Rti_X = Rt_Rti_X.reshape(X.shape)

    if False:
        from chroma.layers.structure.diffusion import _debug_viz_XZC

        _debug_viz_XZC(X, Ri_X, C)

    assert torch.allclose(X, R_Ri_X, atol=1e-2)
    assert torch.allclose(X, Rt_Rti_X, atol=1e-2)
    assert not torch.allclose(Ri_X, R_Ri_X, atol=1e-2)
    assert not torch.allclose(Rti_X, Rt_Rti_X, atol=1e-2)


def test_invertibility_covariance(noise, XCS, debug=False):
    """Test invertibility of the covariance matrix.

    Note: the covariance matrix is poorly conditioned for all but the smallest
    systems, so numerical verification requires a small system and a large
    tolerance.
    """
    X, C, S = XCS

    # Cycle constraint
    Ci_X = noise.multiply_inverse_covariance(X, C)
    C_Ci_X = noise.multiply_covariance(Ci_X, C)

    if debug and not torch.allclose(X, C_Ci_X, atol=1e-1):
        from matplotlib import pyplot as plt

        plt.figure()
        plt.subplot(3, 1, 1)
        plt.plot((X - C_Ci_X).data.numpy().flatten(), ".")
        plt.subplot(3, 1, 2)
        plt.plot(X.data.numpy().flatten())
        plt.subplot(3, 1, 3)
        plt.plot(C.data.numpy().flatten())
        plt.savefig("test_icov.pdf")

    assert torch.allclose(X, C_Ci_X, atol=1e-1)
    assert not torch.allclose(Ci_X, C_Ci_X, atol=1e-1)
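

# The loose tolerance above reflects how ill-conditioned the full covariance is.
# The sketch below (not collected by pytest) is one way to put a number on that:
# it materializes the dense RRt through ConditionalBackboneMVNGlobular, as
# test_log_determinant does, and returns its condition number. The all-ones clamp
# mask D and the system size are arbitrary illustrative choices.
def _covariance_condition_number_sketch(num_residues=25):
    C = torch.ones([1, num_residues])
    X = impute_masked_X(10 * torch.randn([1, num_residues, 4, 3]), C)
    cmvn = ConditionalBackboneMVNGlobular(
        covariance_model="globular",
        complex_scaling=True,
        X=X,
        C=C,
        D=C.ne(0).float(),
    )
    _, RRt = cmvn._materialize_RRt(C)
    return np.linalg.cond(RRt.data.numpy())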


def test_log_determinant(noise):
    """Test log determinant of the covariance matrix."""
    X, C, S = Protein("5imm").to_XCS()
    X = X[0:1, ...]
    C = C[0:1, ...]
    C = torch.abs(C)
    X = impute_masked_X(X, C)

    if hasattr(noise, "covariance_model"):
        # Use the conditional covariance model to build a dense RRt
        cmvn = ConditionalBackboneMVNGlobular(
            covariance_model=noise.covariance_model,
            complex_scaling=noise.complex_scaling,
            X=X,
            C=C,
            D=C.ne(0).float(),
        )
        R, RRt = cmvn._materialize_RRt(C)
        R = R.data.numpy()
        logdet_dense = 3.0 * np.linalg.slogdet(R)[1]

        logdet = noise.log_determinant(C)
        assert logdet.item() == pytest.approx(logdet_dense.item())


def test_cmvn(noise):
    if isinstance(noise, BackboneMVNResidueGas):
        # Conditional sampling is only exercised for the globular covariance models
        return

    aligner = BackboneRMSD()
    protein = Protein("1drf")
    X, C, S = protein.to_XCS()
    protein.sys.save_selection(gti=list(range(14)), selname="clamp")

    cmvn = ConditionalBackboneMVNGlobular(
        covariance_model=noise.covariance_model,
        complex_scaling=noise.complex_scaling,
        X=X,
        C=C,
        D=protein.get_mask("namesel clamp"),
    )

    X_sample = cmvn.sample()
    _, rmsd = aligner.align(X, X_sample, protein.get_mask("namesel clamp"))
    assert rmsd.item() < 1e-1
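

# A minimal sketch of how the implicit square root exercised above can be used to
# draw a correlated backbone sample directly: x = R z with z ~ N(0, I), via the
# same _multiply_R / log_determinant calls as the tests. Shapes follow the
# synthetic branch of the XCS fixture; this runs only when the module is executed
# directly, not under pytest.
if __name__ == "__main__":
    num_batch, num_residues = 1, 100
    mvn = BackboneMVNGlobular(covariance_model="globular", complex_scaling=True)
    C = torch.ones([num_batch, num_residues])
    Z = torch.randn([num_batch, num_residues * 4, 3])
    X_sample = mvn._multiply_R(Z, C).reshape([num_batch, num_residues, 4, 3])
    print("Sampled backbone shape:", tuple(X_sample.shape))
    print("Covariance log-determinant:", mvn.log_determinant(C).item())
    print(
        "Dense RRt condition number (25 residues):",
        _covariance_condition_number_sketch(),
    )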