# Spaces:
# Sleeping
# Sleeping
from unittest import TestCase | |
import numpy as np | |
import pytest | |
import torch | |
from chroma.layers.graph import ( | |
MLP, | |
GraphLayer, | |
GraphNN, | |
collect_edges_transpose, | |
edge_mask_causal, | |
permute_graph_embeddings, | |
) | |
class Testcollect_edges_transpose(TestCase):
    """Checks `collect_edges_transpose` on a 3-node graph where every node lists the other two."""

    def test_collect_edges_transpose(self):
        """Compare transposed edge features against hand-checked values for several masks.

        The checks live in a test method (rather than in the class body) so that
        they run as a reported test instead of executing at import time.
        """
        # Simple case of 3 nodes that are all connected to each other.
        edge_idx = torch.tensor([[[1, 2], [0, 2], [0, 1]]])
        edge_h = torch.tensor([[[[1], [2]], [[3], [4]], [[5], [6]]]])
        full_mask = torch.tensor([[[1, 1], [1, 1], [1, 1]]])

        def transpose_with(mask_ij):
            # Only the transposed edge features are inspected; mask_ji is ignored.
            edge_h_transpose, _mask_ji = collect_edges_transpose(
                edge_h, edge_idx, mask_ij
            )
            return edge_h_transpose

        def assert_flat_equal(expected, actual):
            # Flatten with view(-1) so expectations are easy to write, then
            # count element-wise mismatches.
            mismatches = (torch.tensor(expected) != actual.view(-1)).sum().item()
            self.assertEqual(mismatches, 0)

        # Nothing masked: values were checked by hand.
        edge_h_transpose = transpose_with(full_mask)
        assert_flat_equal([3.0, 5.0, 1.0, 6.0, 2.0, 4.0], edge_h_transpose)
        # The shape must stay the same as the input.
        self.assertEqual(edge_h.shape, edge_h_transpose.shape)

        # Trivial case: if everything is masked, every edge should be zero.
        edge_h_transpose = transpose_with(torch.zeros_like(full_mask))
        self.assertEqual(edge_h_transpose.abs().sum().item(), 0)

        # Mask the connection between nodes 1 and 2 (both directions).
        edge_h_transpose = transpose_with(torch.tensor([[[1, 1], [1, 0], [1, 0]]]))
        assert_flat_equal([3.0, 5.0, 1.0, 0.0, 2.0, 0.0], edge_h_transpose)

        # Mask 0 -> 2 but not 2 -> 0:
        # 2 -> 0 should still end up masked in the transpose.
        edge_h_transpose = transpose_with(torch.tensor([[[1, 0], [1, 0], [1, 0]]]))
        assert_flat_equal([3.0, 0.0, 1.0, 0.0, 0.0, 0.0], edge_h_transpose)
class TestGraphNN(TestCase):
    """Shape check for a forward pass through `GraphNN`."""

    def test_sample(self):
        """A 6-layer GraphNN must return node/edge tensors shaped like its inputs."""
        dim_nodes, dim_edges = 128, 64
        num_nodes, num_neighbors = 10, 8
        model = GraphNN(num_layers=6, dim_nodes=dim_nodes, dim_edges=dim_edges)

        node_h = torch.ones(1, num_nodes, dim_nodes)
        edge_h = torch.ones(1, num_nodes, num_neighbors, dim_edges)
        # All-ones long tensor: every neighbor slot holds index 1.
        edge_idx = torch.ones(1, num_nodes, num_neighbors, dtype=torch.long)

        node_h_out, edge_h_out = model(node_h, edge_h, edge_idx)

        # assertEqual reports both shapes on failure, unlike assertTrue(a == b).
        self.assertEqual(node_h_out.shape, (1, num_nodes, dim_nodes))
        self.assertEqual(
            edge_h_out.shape, (1, num_nodes, num_neighbors, dim_edges)
        )
class TestGraphLayer(TestCase):
    """Shape check for a forward pass through a single `GraphLayer`."""

    def test_sample(self):
        """A GraphLayer with edge updates must preserve node and edge tensor shapes."""
        dim_nodes, dim_edges = 128, 64
        num_nodes, num_neighbors = 10, 8
        graph_layer = GraphLayer(
            dim_nodes=dim_nodes, dim_edges=dim_edges, dropout=0, edge_update=True
        )
        # NOTE: an exact parameter-count check (131712) used to sit here but was
        # disabled; the count it relied on is no longer computed.

        node_h = torch.ones(1, num_nodes, dim_nodes)
        edge_h = torch.ones(1, num_nodes, num_neighbors, dim_edges)
        # All-ones long tensor: every neighbor slot holds index 1.
        edge_idx = torch.ones(1, num_nodes, num_neighbors, dtype=torch.long)

        node_h_out, edge_h_out = graph_layer(node_h, edge_h, edge_idx)

        # assertEqual reports both shapes on failure, unlike assertTrue(a == b).
        self.assertEqual(node_h_out.shape, (1, num_nodes, dim_nodes))
        self.assertEqual(
            edge_h_out.shape, (1, num_nodes, num_neighbors, dim_edges)
        )
class TestMLP(TestCase):
    """Output-width checks for `MLP` under different constructor arguments."""

    def test_sample(self):
        """The last output dimension must match the expected width in each configuration."""
        dim_in = 10
        dim_out = 8
        dim_hidden = 5
        sample_input = torch.rand(dim_in)

        # No dim_out given: the output is expected to keep the input width.
        prediction = MLP(dim_in)(sample_input)
        self.assertEqual(prediction.shape[-1], dim_in)

        # Explicit output width.
        prediction = MLP(dim_in, dim_out=dim_out)(sample_input)
        self.assertEqual(prediction.shape[-1], dim_out)

        # Custom hidden width. dim_out is deliberately set to the same value,
        # so the expected output width equals dim_hidden.
        prediction = MLP(dim_in, dim_hidden=dim_hidden, dim_out=dim_hidden)(
            sample_input
        )
        self.assertEqual(prediction.shape[-1], dim_hidden)

        # No hidden layers at all.
        prediction = MLP(dim_in, num_layers_hidden=0, dim_out=dim_out)(sample_input)
        self.assertEqual(prediction.shape[-1], dim_out)
class TestGraphFunctions(TestCase):
    """Placeholder test case; it defines no test methods of its own."""

    @staticmethod
    def hello():
        """Print ``hello`` to standard output.

        Declared as a static method because it takes no ``self``; without the
        decorator, calling it on an instance would raise ``TypeError``.
        """
        print("hello")
def test_graph_permutation():
    """Check that node permutations round-trip and that GraphNN commutes with them.

    A random graph embedding is permuted and then inverse-permuted, which must
    reproduce the original tensors. The same permutation is then applied
    around a one-layer GraphNN: running the network on the permuted graph and
    undoing the permutation must match running it on the original graph.
    """
    num_batch, num_nodes, num_neighbors, dim = 2, 7, 4, 3

    # Random graph embedding, in the order
    # (node_h, edge_h, edge_idx, mask_i, mask_ij), with all-ones masks.
    graph = (
        torch.randn([num_batch, num_nodes, dim]),
        torch.randn([num_batch, num_nodes, num_neighbors, dim]),
        torch.randint(
            low=0, high=num_nodes, size=[num_batch, num_nodes, num_neighbors]
        ),
        torch.ones([num_batch, num_nodes]),
        torch.ones([num_batch, num_nodes, num_neighbors]),
    )

    # A random node ordering per batch element, plus the ordering that undoes it.
    forward_idx = torch.argsort(torch.randn([num_batch, num_nodes]), dim=-1)
    inverse_idx = torch.argsort(forward_idx, dim=-1)

    # Permute, then inverse permute.
    (
        node_h_perm,
        edge_h_perm,
        edge_idx_perm,
        mask_i_perm,
        mask_ij_perm,
    ) = permute_graph_embeddings(*graph, forward_idx)
    restored = permute_graph_embeddings(
        node_h_perm, edge_h_perm, edge_idx_perm, mask_i_perm, mask_ij_perm, inverse_idx
    )

    # Round trip: permutation followed by its inverse is the identity.
    for original, round_tripped in zip(graph, restored):
        assert torch.allclose(original, round_tripped)

    # Permutation equivariance of the GNN layers.
    gnn = GraphNN(num_layers=1, dim_nodes=dim, dim_edges=dim)
    reference = gnn(*graph)
    shuffled = gnn(node_h_perm, edge_h_perm, edge_idx_perm, mask_i_perm, mask_ij_perm)
    unshuffled = permute_graph_embeddings(
        shuffled[0], shuffled[1], edge_idx_perm, mask_i_perm, mask_ij_perm, inverse_idx
    )
    for position in (0, 1):
        assert torch.allclose(reference[position], unshuffled[position])
def test_autoregressive_gnn():
    """Check that step-wise GraphNN evaluation reproduces the parallel forward pass.

    For both values of the ``attentional`` flag, a 4-layer GraphNN is run once
    on the whole graph and once node by node through ``init_steps``/``step``
    with a causal edge mask; the two results must agree.
    """
    B, N, K, H = 1, 3, 3, 4
    torch.manual_seed(0)

    # Build random GNN input.
    node_h = torch.randn([B, N, H])
    edge_h = torch.randn([B, N, K, H])
    # Every node lists neighbors 0..K-1 in the same order.
    edge_idx = torch.arange(K).reshape([1, 1, K]).expand([B, N, K]).contiguous()
    mask_i = torch.ones([B, N])
    mask_ij = edge_mask_causal(edge_idx, torch.ones([B, N, K]))

    for attentional in (True, False):
        gnn = GraphNN(
            num_layers=4, dim_nodes=H, dim_edges=H, attentional=attentional
        )

        # Parallel computation over the whole graph.
        node_h_gnn, edge_h_gnn = gnn(node_h, edge_h, edge_idx, mask_i, mask_ij)

        # Step-wise computation, one node index at a time.
        node_h_cache, edge_h_cache = gnn.init_steps(node_h, edge_h)
        for t in range(N):
            node_h_cache, edge_h_cache = gnn.step(
                t, node_h_cache, edge_h_cache, edge_idx, mask_i, mask_ij
            )

        # The last cache entries are compared with the parallel outputs
        # (presumably the final layer's state -- confirm against GraphNN.step).
        assert torch.allclose(node_h_gnn, node_h_cache[-1])
        assert torch.allclose(edge_h_gnn, edge_h_cache[-1])