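"""Tests for the graph neural network layers in chroma.layers.graph.

Covers edge-transpose gathering, forward-pass shapes for GraphNN, GraphLayer,
and MLP, permutation equivariance, and autoregressive (causal) decoding.
"""
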
from unittest import TestCase

import torch

from chroma.layers.graph import (
    MLP,
    GraphLayer,
    GraphNN,
    collect_edges_transpose,
    edge_mask_causal,
    permute_graph_embeddings,
)


class TestCollectEdgesTranspose(TestCase):
    def test_collect_edges_transpose(self):
        # Simple case: three nodes that are all connected to each other
        edge_idx = torch.tensor([[[1, 2], [0, 2], [0, 1]]])
        mask_ij = torch.tensor([[[1, 1], [1, 1], [1, 1]]])
        edge_h = torch.tensor([[[[1], [2]], [[3], [4]], [[5], [6]]]])

        edge_h_transpose, mask_ji = collect_edges_transpose(edge_h, edge_idx, mask_ij)

        # Expected values were inspected manually; view(-1) flattens the
        # tensor so the expected values are easier to write out.
        assert (
            edge_h_transpose.view(-1) != torch.tensor([3.0, 5.0, 1.0, 6.0, 2.0, 4.0])
        ).sum().item() == 0
        # The transpose should preserve the shape of the edge features
        assert edge_h.shape == edge_h_transpose.shape

        # Degenerate case: if everything is masked, all edges should be zero
        edge_h_transpose, mask_ji = collect_edges_transpose(
            edge_h, edge_idx, torch.zeros_like(mask_ij)
        )
        assert edge_h_transpose.abs().sum() == 0

        # Mask the connections between nodes 1 and 2 in both directions
        mask_ij = torch.tensor([[[1, 1], [1, 0], [1, 0]]])
        edge_h_transpose, mask_ji = collect_edges_transpose(edge_h, edge_idx, mask_ij)
        assert (
            edge_h_transpose.view(-1) != torch.tensor([3.0, 5.0, 1.0, 0.0, 2.0, 0.0])
        ).sum().item() == 0

        # Mask the edge 0 -> 2 but not 2 -> 0:
        # 2 -> 0 should then be masked in the transpose
        mask_ij = torch.tensor([[[1, 0], [1, 0], [1, 0]]])
        edge_h_transpose, mask_ji = collect_edges_transpose(edge_h, edge_idx, mask_ij)
        assert (
            edge_h_transpose.view(-1) != torch.tensor([3.0, 0.0, 1.0, 0.0, 0.0, 0.0])
        ).sum().item() == 0
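
        # Note: taken together, these cases suggest that collect_edges_transpose
        # gathers, for each directed edge (i -> j), the feature stored on the
        # reverse edge (j -> i), zeroing entries whose reverse edge is missing
        # or masked in either direction.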


class TestGraphNN(TestCase):
    def test_sample(self):
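        """GraphNN forward pass should preserve node and edge feature shapes."""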
        dim_nodes = 128
        dim_edges = 64

        model = GraphNN(num_layers=6, dim_nodes=dim_nodes, dim_edges=dim_edges)

        num_nodes = 10
        num_neighbors = 8
        node_h_out, edge_h_out = model(
            torch.ones(1, num_nodes, dim_nodes),
            torch.ones(1, num_nodes, num_neighbors, dim_edges),
            torch.ones(1, num_nodes, num_neighbors, dtype=torch.long),
        )
        self.assertTrue(node_h_out.shape == (1, num_nodes, dim_nodes))
        self.assertTrue(edge_h_out.shape == (1, num_nodes, num_neighbors, dim_edges))


class TestGraphLayer(TestCase):
    def test_sample(self):
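        """GraphLayer forward pass should preserve node and edge feature shapes."""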

        dim_nodes = 128
        dim_edges = 64

        graph_layer = GraphLayer(
            dim_nodes=dim_nodes, dim_edges=dim_edges, dropout=0, edge_update=True
        )

        num_nodes = 10
        num_neighbors = 8
        node_h_out, edge_h_out = graph_layer(
            torch.ones(1, num_nodes, dim_nodes),
            torch.ones(1, num_nodes, num_neighbors, dim_edges),
            torch.ones(1, num_nodes, num_neighbors, dtype=torch.long),
        )
        self.assertTrue(node_h_out.shape == (1, num_nodes, dim_nodes))
        self.assertTrue(edge_h_out.shape == (1, num_nodes, num_neighbors, dim_edges))


class TestMLP(TestCase):
    def test_sample(self):
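        """MLP output size should follow dim_out, defaulting to dim_in, for
        custom hidden sizes and for zero hidden layers."""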
        dim_in = 10
        sample_input = torch.rand(dim_in)
        prediction = MLP(dim_in)(sample_input)
        self.assertTrue(prediction.shape[-1] == dim_in)

        sample_input = torch.rand(dim_in)
        dim_out = 8
        model = MLP(dim_in, dim_out=dim_out)
        prediction = model(sample_input)
        self.assertTrue(prediction.shape[-1] == dim_out)

        sample_input = torch.rand(dim_in)
        dim_hidden = 5
        model = MLP(dim_in, dim_hidden=dim_hidden, dim_out=dim_hidden)
        prediction = model(sample_input)
        self.assertTrue(prediction.shape[-1] == dim_hidden)

        sample_input = torch.rand(dim_in)
        model = MLP(dim_in, num_layers_hidden=0, dim_out=dim_out)
        prediction = model(sample_input)
        self.assertTrue(prediction.shape[-1] == dim_out)


def test_graph_permutation():
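    """Permuting then inverse-permuting graph embeddings should round-trip,
    and GraphNN should be equivariant to node permutations."""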
    B, N, K, H = 2, 7, 4, 3
    # Create a random graph embedding
    node_h = torch.randn([B, N, H])
    edge_h = torch.randn([B, N, K, H])
    edge_idx = torch.randint(low=0, high=N, size=[B, N, K])
    mask_i = torch.ones([B, N])
    mask_ij = torch.ones([B, N, K])

    # Create a random permutation (as an index vector per batch element)
    permute_idx = torch.argsort(torch.randn([B, N]), dim=-1)

    # Permute
    node_h_p, edge_h_p, edge_idx_p, mask_i_p, mask_ij_p = permute_graph_embeddings(
        node_h, edge_h, edge_idx, mask_i, mask_ij, permute_idx
    )

    # Inverse permute
    permute_idx_inverse = torch.argsort(permute_idx, dim=-1)
    node_h_pp, edge_h_pp, edge_idx_pp, mask_i_pp, mask_ij_pp = permute_graph_embeddings(
        node_h_p, edge_h_p, edge_idx_p, mask_i_p, mask_ij_p, permute_idx_inverse
    )

    # Test that the inverse permutation undoes the permutation (round trip)
    assert torch.allclose(node_h, node_h_pp)
    assert torch.allclose(edge_h, edge_h_pp)
    assert torch.equal(edge_idx, edge_idx_pp)
    assert torch.allclose(mask_i, mask_i_pp)
    assert torch.allclose(mask_ij, mask_ij_pp)

    # Test permutation equivariance of GNN layers
    gnn = GraphNN(num_layers=1, dim_nodes=H, dim_edges=H)
    outs = gnn(node_h, edge_h, edge_idx, mask_i, mask_ij)
    outs_perm = gnn(node_h_p, edge_h_p, edge_idx_p, mask_i_p, mask_ij_p)
    outs_pp = permute_graph_embeddings(
        outs_perm[0], outs_perm[1], edge_idx_p, mask_i_p, mask_ij_p, permute_idx_inverse
    )

    assert torch.allclose(outs[0], outs_pp[0])
    assert torch.allclose(outs[1], outs_pp[1])


def test_autoregressive_gnn():
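    """Step-wise (sequential) decoding under a causal edge mask should match
    the parallel forward pass, with and without attention."""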
    B, N, K, H = 1, 3, 3, 4

    torch.manual_seed(0)

    # Build random GNN input
    node_h = torch.randn([B, N, H])
    edge_h = torch.randn([B, N, K, H])
    # Fixed neighborhoods: every node is connected to nodes 0..K-1
    edge_idx = torch.arange(K).reshape([1, 1, K]).expand([B, N, K]).contiguous()
    mask_i = torch.ones([B, N])
    mask_ij = torch.ones([B, N, K])
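    # Causal mask for autoregressive decoding: edges to future nodes are masked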
    mask_ij = edge_mask_causal(edge_idx, mask_ij)

    # Compare parallel and step-wise computation, with and without attention
    for attentional in [True, False]:
        gnn = GraphNN(num_layers=4, dim_nodes=H, dim_edges=H, attentional=attentional)

        node_h_gnn, edge_h_gnn = gnn(node_h, edge_h, edge_idx, mask_i, mask_ij)

        # Step-wise computation
        node_h_cache, edge_h_cache = gnn.init_steps(node_h, edge_h)
        for t in range(N):
            node_h_cache, edge_h_cache = gnn.step(
                t, node_h_cache, edge_h_cache, edge_idx, mask_i, mask_ij
            )
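        # The final cache entries hold the outputs of the last GNN layer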
        node_h_sequential = node_h_cache[-1]
        edge_h_sequential = edge_h_cache[-1]

        assert torch.allclose(node_h_gnn, node_h_sequential)
        assert torch.allclose(edge_h_gnn, edge_h_sequential)