File size: 4,836 Bytes
e9e75df
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
import pytest
import torch

from chroma.data import Protein
from chroma.layers.structure import backbone, protein_graph
from chroma.models.graph_backbone import GraphBackbone


def test_denoiser(dim_nodes=32, dim_edges=32):
    """Smoke-test GraphBackbone.denoise output shapes and prediction modes.

    Loads PDB entry 1SHG, then checks that denoising preserves the backbone
    tensor shape, that ``prediction_type="scale"`` approximately reproduces
    the input at t=0, and that denoising runs with scaled-down coordinates.
    """
    # Sequence tensor S is not needed for these checks.
    X, C, _ = Protein("1SHG").to_XCS()
    model = GraphBackbone(dim_nodes=dim_nodes, dim_edges=dim_edges)

    # Check that the denoiser preserves the input backbone shape at t=0.
    model.CA_dist_scaling = False
    X0 = model.denoise(X, C, 0.0)
    assert X0.shape == X.shape

    # With prediction_type="scale" and t=0 the model should approximately
    # return the input structure unchanged.
    model = GraphBackbone(prediction_type="scale")
    X0 = model.denoise(X, C, 0.0)
    assert torch.allclose(X0, X, rtol=1e-2)

    # Denoising a uniformly shrunken structure at small t should still
    # produce a tensor of the original shape.
    model.CA_dist_scaling = False
    X0 = model.denoise(0.25 * X, C, 1e-4)
    assert X0.shape == X.shape
    # assert model._D_backbone_CA(X0, C).min().item() < model.min_CA_bb_distance


@pytest.mark.parametrize("t", [0.1, 0.7, 1.0])
def test_equivariance_denoiser(t, dim_nodes=32, dim_edges=32, seed=10, debug=False):
    """Test that denoising commutes with rigid transformations.

    Builds a synthetic alpha-helical backbone, applies a fixed rigid
    transform, and checks that transform-then-denoise matches
    denoise-then-transform (up to a tolerance; see note below about the
    terminal oxygen atom).
    """
    X = backbone.ProteinBackbone(num_batch=1, num_residues=20, init_state="alpha")()
    C = torch.ones(X.shape[:2])
    S = torch.zeros_like(C).long()

    model = GraphBackbone(dim_nodes=dim_nodes, dim_edges=dim_edges).eval()

    # Fixed rigid transform (rotation quaternion + translation), applied
    # identically on both sides of the equivariance comparison.
    transformer = backbone.RigidTransformer(center_rotation=False)
    q_transform = torch.Tensor([0.0, 0.1, -1.2, 0.5]).unsqueeze(0)
    dX_transform = torch.Tensor([-3.0, 30.0, 7.0]).unsqueeze(0)
    _transform = lambda X_input: transformer(X_input, dX_transform, q_transform)

    # Add noise at the parametrized diffusion time t.
    X_noised = model.noise_perturb(X, C, t=t)
    X_noised_transform = _transform(X_noised)

    # Synchronize random seeds so both denoise calls build the same
    # random graph.
    torch.manual_seed(seed)
    X_denoised = model.denoise(X_noised, C, t=t)
    X_denoised_transform = _transform(X_denoised)
    torch.manual_seed(seed)
    X_transform_denoised = model.denoise(X_noised_transform, C, t=t)

    if debug:
        print((X_denoised_transform - X_transform_denoised).abs().max())

        # Fix: the original input X was written to "X_denoised.cif" and
        # immediately overwritten by the next line; write it to "X.cif".
        Protein(X, C, S).to_CIF("X.cif")
        Protein(X_denoised, C, S).to_CIF("X_denoised.cif")
        Protein(X_denoised_transform, C, S).to_CIF("X_denoised_transform.cif")
        Protein(X_transform_denoised, C, S).to_CIF("X_transform_denoised.cif")

    # The oxygen atom of the final carboxy terminus residue in each chain
    # is disambiguated via zero-padding (non-equivariant), so it can be up to \
    # ~1 angstrom off depending on global pose
    assert torch.allclose(
        X_denoised_transform[:, :-1, :, :],
        X_transform_denoised[:, :-1, :, :],
        atol=1e-1,
    )
    # Nevertheless at this adjusted tolerance we are equivariant
    assert torch.allclose(X_denoised_transform, X_transform_denoised, atol=3.0)
    # Sanity check: the transform actually moved the structure, so the
    # untransformed and transformed denoised outputs must differ.
    assert not torch.allclose(X_denoised, X_transform_denoised, atol=1e-1)


@pytest.mark.parametrize("num_transform_weights", [1, 2, 3])
@pytest.mark.parametrize("dim_nodes", [32])
@pytest.mark.parametrize("dim_edges", [32])
def test_equivariance_graph_update(
    num_transform_weights, dim_nodes, dim_edges, output_structures=False
):
    """Test rotation equivariance of GraphBackboneUpdate.

    Canonicalizes a real structure (1qys) through the frame builder, then
    checks that rotating the input before the backbone update matches
    rotating the update's output.
    """
    # Fix: seed must be an int; 10.0 only worked via implicit coercion.
    torch.manual_seed(10)

    # Initialize layers
    bb_update = backbone.GraphBackboneUpdate(
        dim_nodes=dim_nodes,
        dim_edges=dim_edges,
        method="neighbor_global_affine",
        num_transform_weights=num_transform_weights,
    ).eval()
    pg = protein_graph.ProteinFeatureGraph(
        dim_nodes=dim_nodes, dim_edges=dim_edges, num_neighbors=5
    )

    # Pure rotation (zero translation) for the equivariance check.
    transformer = backbone.RigidTransformer(center_rotation=False)
    q_rotate = torch.Tensor([0.0, 0.1, -1.2, 0.5]).unsqueeze(0)
    dX_rotate = torch.Tensor([0.0, 0.0, 0.0]).unsqueeze(0)
    _rotate = lambda X_input: transformer(X_input, dX_rotate, q_rotate)

    # Load test structure and canonicalize it by a round trip through the
    # frame builder so X is exactly representable by (R, t) frames.
    X, C, S = Protein("1qys").to_XCS()
    R, t, _ = bb_update.frame_builder.inverse(X, C)
    X = bb_update.frame_builder.forward(R, t, C)

    # Backbone update on the original pose. Note: node/edge features are
    # computed once from the unrotated X and reused for the rotated pass.
    node_h, edge_h, edge_idx, mask_i, mask_ij = pg(X, C)
    X_update, _, _, _ = bb_update(X, C, node_h, edge_h, edge_idx, mask_i, mask_ij)

    # Backbone update on the rotated pose.
    X_rotate = _rotate(X)
    X_rotate_update, _, _, _ = bb_update(
        X_rotate, C, node_h, edge_h, edge_idx, mask_i, mask_ij
    )
    X_update_rotate = _rotate(X_update)

    # rotate-then-update must equal update-then-rotate.
    assert torch.allclose(X_rotate_update, X_update_rotate, atol=1e-2)

    if output_structures:
        from chroma.layers.structure.rmsd import BackboneRMSD

        bb_rmsd = BackboneRMSD()
        X_aligned, rmsd = bb_rmsd.align(X_rotate_update, X_update_rotate, C)
        print(rmsd)
        Protein.from_XCS_trajectory(
            [X, X_update, X_rotate, X_rotate_update, X_update_rotate], C, S
        ).to_CIF("test_equi.cif")