Spaces:
Sleeping
Sleeping
File size: 3,587 Bytes
ce7bf5b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 |
import numpy as np
import pytest
import torch
import torch.nn.functional as F
from chroma.data import Protein
from chroma.layers.structure.backbone import RigidTransformer
from chroma.layers.structure.protein_graph import ProteinFeatureGraph
def test_protein_feature_graph():
    """Check ProteinFeatureGraph output shapes, masking, and rigid invariance.

    The test builds a feature graph with internal-coordinate node features
    and a set of distance / orientation / position edge features, runs it on
    the structure returned by ``Protein("5imm").to_XCS()``, and verifies that:

    * every returned tensor has the expected shape,
    * node and edge features are exactly zero wherever their mask is zero,
    * node features, edge features and neighbor indices are unchanged when
      the whole structure is rotated and translated.
    """
    torch.manual_seed(10)
    dim_nodes, dim_edges = 128, 64
    num_neighbors = 30
    feature_graph = ProteinFeatureGraph(
        dim_nodes=dim_nodes,
        dim_edges=dim_edges,
        node_features=(("internal_coords", {"log_lengths": True}),),
        edge_features=(
            "distances_6mer",
            "distances_2mer",
            "orientations_2mer",
            "distances_chain",
            "orientations_chain",
            "position_2mer",
        ),
        num_neighbors=num_neighbors,
        graph_kwargs={"mask_interfaces": False, "cutoff": None},
    )

    # Only X and C are consumed below; the third return value is not needed.
    X, C, _ = Protein("5imm").to_XCS()
    node_h, edge_h, edge_idx, mask_i, mask_ij = feature_graph(X, C)
    num_nodes = X.shape[1]

    # Test shapes (the leading dimension of 1 is the single-structure batch).
    assert node_h.shape == (1, num_nodes, dim_nodes)
    assert edge_h.shape == (1, num_nodes, num_neighbors, dim_edges)
    assert edge_idx.shape == (1, num_nodes, num_neighbors)
    assert mask_i.shape == (1, num_nodes)
    assert mask_ij.shape == (1, num_nodes, num_neighbors)

    # Test masks: features at masked-out nodes/edges must be exactly zero.
    masked_sum_i = torch.abs((1.0 - mask_i).unsqueeze(-1) * node_h).sum()
    masked_sum_ij = torch.abs((1.0 - mask_ij).unsqueeze(-1) * edge_h).sum()
    assert masked_sum_i == 0
    assert masked_sum_ij == 0

    # Test feature invariance to rotation and translation.
    # q_rotate has four components (presumably a quaternion -- confirm against
    # RigidTransformer); dX_rotate is the accompanying 3-vector offset.
    transformer = RigidTransformer(center_rotation=False)
    q_rotate = torch.Tensor([0.0, 0.1, -1.2, 0.5]).unsqueeze(0)
    dX_rotate = torch.Tensor([0.0, 1.0, -1.0]).unsqueeze(0)
    X_transformed = transformer(X, dX_rotate, q_rotate)
    node_h_r, edge_h_r, edge_idx_r, _, _ = feature_graph(X_transformed, C)

    # Guard against a vacuous pass: the coordinates must actually have moved.
    assert not torch.allclose(X, X_transformed, atol=1e-3)
    assert torch.allclose(node_h_r, node_h, atol=1e-3)
    assert torch.allclose(edge_h_r, edge_h, atol=1e-3)
    # Neighbor indices are compared exactly: a float tolerance is not
    # meaningful for indices (assumes edge_idx holds integer indices).
    assert torch.equal(edge_idx, edge_idx_r)
def test_masked_interfaces():
    """Check that interface-masked features ignore single-chain rigid motion.

    The feature graph is built with ``mask_interfaces=True`` (and
    ``require_contiguous=True`` on the 6-mer distances). The test then moves
    only the residues where ``C == 2`` with a rigid transform and asserts
    that node features, edge features and neighbor indices are unchanged.
    """
    torch.manual_seed(10)
    dim_nodes, dim_edges = 128, 64
    num_neighbors = 30
    feature_graph = ProteinFeatureGraph(
        dim_nodes=dim_nodes,
        dim_edges=dim_edges,
        node_features=(("internal_coords", {"log_lengths": True}),),
        edge_features=(
            ("distances_6mer", {"require_contiguous": True}),
            "distances_2mer",
            "orientations_2mer",
            "distances_chain",
            "orientations_chain",
        ),
        num_neighbors=num_neighbors,
        graph_kwargs={"mask_interfaces": True, "cutoff": None},
    )

    # Only X and C are consumed below; the third return value is not needed.
    X, C, _ = Protein("5imm").to_XCS()
    # The masks returned by the graph are not inspected in this test.
    node_h, edge_h, edge_idx, _, _ = feature_graph(X, C)

    # Test feature invariance to *single-chain* rotation and translation.
    # chain_mask is 1.0 where C == 2 and 0.0 elsewhere, so the transform is
    # restricted to those positions via the ``mask`` argument.
    rigid_transformer = RigidTransformer(center_rotation=False)
    chain_mask = (C == 2).type(torch.float32)
    q = torch.Tensor([2.0, 1.0, 1.0, 0.5]).unsqueeze(0)
    dX = torch.Tensor([1.0, 2.0, -1.0]).unsqueeze(0)
    X_transformed = rigid_transformer(X, dX, q, mask=chain_mask)
    node_h_r, edge_h_r, edge_idx_r, _, _ = feature_graph(X_transformed, C)

    # Guard against a vacuous pass: if the mask selected nothing (or the
    # transform were a no-op) the invariance checks below would prove nothing.
    assert not torch.allclose(X, X_transformed, atol=1e-3)
    assert torch.allclose(node_h_r, node_h, atol=1e-3)
    assert torch.allclose(edge_h_r, edge_h, atol=1e-3)
    # Neighbor indices are compared exactly: a float tolerance is not
    # meaningful for indices (assumes edge_idx holds integer indices).
    assert torch.equal(edge_idx, edge_idx_r)
|