Spaces:
Sleeping
Sleeping
import numpy as np | |
import pytest | |
import torch | |
import torch.nn.functional as F | |
from chroma.data import Protein | |
from chroma.layers.structure.backbone import RigidTransformer | |
from chroma.layers.structure.protein_graph import ProteinFeatureGraph | |
def test_protein_feature_graph(): | |
torch.manual_seed(10) | |
dim_nodes, dim_edges = 128, 64 | |
num_neighbors = 30 | |
feature_graph = ProteinFeatureGraph( | |
dim_nodes=dim_nodes, | |
dim_edges=dim_edges, | |
node_features=(("internal_coords", {"log_lengths": True}),), | |
edge_features=( | |
"distances_6mer", | |
"distances_2mer", | |
"orientations_2mer", | |
"distances_chain", | |
"orientations_chain", | |
"position_2mer", | |
), | |
num_neighbors=num_neighbors, | |
graph_kwargs={"mask_interfaces": False, "cutoff": None}, | |
) | |
X, C, S = Protein("5imm").to_XCS() | |
node_h, edge_h, edge_idx, mask_i, mask_ij = feature_graph(X, C) | |
num_nodes = X.shape[1] | |
# Test shapes | |
assert node_h.shape == (1, num_nodes, dim_nodes) | |
assert edge_h.shape == (1, num_nodes, num_neighbors, dim_edges) | |
assert edge_idx.shape == (1, num_nodes, num_neighbors) | |
assert mask_i.shape == (1, num_nodes) | |
assert mask_ij.shape == (1, num_nodes, num_neighbors) | |
# Test masks | |
masked_sum_i = torch.abs((1.0 - mask_i).unsqueeze(-1) * node_h).sum() | |
masked_sum_ij = torch.abs((1.0 - mask_ij).unsqueeze(-1) * edge_h).sum() | |
assert masked_sum_i == 0 | |
assert masked_sum_ij == 0 | |
transformer = RigidTransformer(center_rotation=False) | |
q_rotate = torch.Tensor([0.0, 0.1, -1.2, 0.5]).unsqueeze(0) | |
dX_rotate = torch.Tensor([0.0, 1.0, -1.0]).unsqueeze(0) | |
_rotate = lambda X_input: transformer(X_input, dX_rotate, q_rotate) | |
# Test feature invariance to rotation and translation | |
X_transformed = _rotate(X) | |
node_h_r, edge_h_r, edge_idx_r, mask_i_r, mask_ij_r = feature_graph( | |
X_transformed, C | |
) | |
assert not torch.allclose(X, X_transformed, atol=1e-3) | |
assert torch.allclose(node_h_r, node_h, atol=1e-3) | |
assert torch.allclose(edge_h_r, edge_h, atol=1e-3) | |
assert torch.allclose(edge_idx, edge_idx_r, atol=1e-3) | |
def test_masked_interfaces(): | |
torch.manual_seed(10) | |
dim_nodes, dim_edges = 128, 64 | |
num_neighbors = 30 | |
feature_graph = ProteinFeatureGraph( | |
dim_nodes=dim_nodes, | |
dim_edges=dim_edges, | |
node_features=(("internal_coords", {"log_lengths": True}),), | |
edge_features=( | |
("distances_6mer", {"require_contiguous": True}), | |
"distances_2mer", | |
"orientations_2mer", | |
"distances_chain", | |
"orientations_chain", | |
), | |
num_neighbors=num_neighbors, | |
graph_kwargs={"mask_interfaces": True, "cutoff": None}, | |
) | |
X, C, S = Protein("5imm").to_XCS() | |
node_h, edge_h, edge_idx, mask_i, mask_ij = feature_graph(X, C) | |
num_nodes = X.shape[1] | |
# Test feature invariance to *single-chain* rotation and translation | |
rigid_transformer = RigidTransformer(center_rotation=False) | |
chain_mask = (C == 2).type(torch.float32) | |
q = torch.Tensor([2.0, 1.0, 1.0, 0.5]).unsqueeze(0) | |
dX = torch.Tensor([1.0, 2.0, -1.0]).unsqueeze(0) | |
X_transformed = rigid_transformer(X, dX, q, mask=chain_mask) | |
node_h_r, edge_h_r, edge_idx_r, mask_i_r, mask_ij_r = feature_graph( | |
X_transformed, C | |
) | |
assert torch.allclose(node_h_r, node_h, atol=1e-3) | |
assert torch.allclose(edge_h_r, edge_h, atol=1e-3) | |
assert torch.allclose(edge_idx, edge_idx_r, atol=1e-3) | |