|
The BaselineModel class in the baselines.py file is a full working Graph Neural Network (GNN) example using JAX and the DeepMind JAX Ecosystem of libraries. It allows training multiple algorithms on a single processor, as described in the paper "A Generalist Neural Algorithmic Learner" (arXiv:2209.11142v2 [cs.LG], 3 Dec 2022). Below is an excerpt from the paper that describes the model:
|
|
|
Each algorithm in the CLRS benchmark [5] is specified by a number of inputs, hints and outputs. In a given sample, the inputs and outputs are fixed, while hints are time-series of intermediate states of the algorithm. Each sample for a particular task has a size, n, corresponding to the number of nodes in the GNN that will execute the algorithm.

A sample of every algorithm is represented as a graph, with each input, output and hint located in either the nodes, the edges, or the graph itself, and therefore has shape (excluding the batch dimension and, for hints, the time dimension) n × f, n × n × f, or f, respectively, f being the dimensionality of the feature, which depends on its type. The CLRS benchmark defines five types of features: scalar, categorical, mask, mask_one and pointer, with their own encoding and decoding strategies and loss functions—e.g. a scalar type will be encoded and decoded directly by a single linear layer, and optimised using mean squared error.
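To make the handling of a single feature type concrete, here is a minimal, hypothetical sketch (not the library's actual API) of how a scalar node feature might be embedded by one linear layer and trained with mean squared error, using Haiku:

```python
import jax
import jax.numpy as jnp
import haiku as hk

HIDDEN = 128  # h, the embedding dimensionality used in the paper

def encode_decode_scalar(node_scalar):
    """Encode an n-dimensional scalar node feature and decode it again.

    node_scalar: [n] array, e.g. the `key` input of a sorting task.
    Returns the [n, HIDDEN] embedding and the [n] reconstruction.
    """
    # Scalar features are encoded by a single linear layer: f = 1 -> h.
    embedding = hk.Linear(HIDDEN)(node_scalar[:, None])   # [n, h]
    # ...and decoded by a single linear layer: h -> 1.
    decoded = hk.Linear(1)(embedding)[:, 0]               # [n]
    return embedding, decoded

def mse_loss(pred, target):
    # Scalar features are optimised with mean squared error.
    return jnp.mean((pred - target) ** 2)

# Usage with Haiku's functional transform (hypothetical data):
fn = hk.transform(encode_decode_scalar)
x = jax.random.uniform(jax.random.PRNGKey(0), (16,))      # n = 16 nodes
params = fn.init(jax.random.PRNGKey(1), x)
emb, x_hat = fn.apply(params, None, x)
loss = mse_loss(x_hat, x)
```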
|
|
|
Base Model |
|
|
|
Encoder. We adopt the same encode-process-decode paradigm [33] presented with the CLRS benchmark [5]. At each time step, t, of a particular task τ (e.g. insertion sort), the task-based encoder f_τ, consisting of a linear encoder for each input and hint, embeds inputs and the current hints as high-dimensional vectors. These embeddings of inputs and hints located in the nodes all have the same dimension and are added together; the same happens with hints and inputs located in edges, and in the graph. In our experiments we use the same dimension, h = 128, for node, edge and graph embeddings. Thus, at the end of the encoding step for a time-step t of the algorithm, we have a single set of embeddings {x_i^{(t)}, e_{ij}^{(t)}, g^{(t)}}, of shapes n × h, n × n × h, and h, in the nodes, edges and graph, respectively. Note that this is independent of the number and type of the inputs and hints of the particular algorithm, allowing us to share this latent space across all thirty algorithms in CLRS. Further, note that at each step, the input encoding is fed directly to these embeddings—this recall mechanism significantly improves the model's robustness over long trajectories [34].
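The following is a minimal sketch of the encoder logic described above, intended to run inside an hk.transform-ed function; it assumes the inputs and current hints have already been converted into numeric arrays, and the function and dictionary names are illustrative, not those used in baselines.py:

```python
import haiku as hk

H = 128  # shared embedding dimensionality for nodes, edges and graph

def encode_step(node_feats, edge_feats, graph_feats):
    """Embed all inputs/current hints and sum them per location.

    node_feats:  dict of [n, f_k] arrays located in the nodes.
    edge_feats:  dict of [n, n, f_k] arrays located in the edges.
    graph_feats: dict of [f_k] arrays located in the graph.
    Returns x: [n, H], e: [n, n, H], g: [H].
    """
    # One linear encoder per input/hint; embeddings that share a location
    # are simply added together.
    x = sum(hk.Linear(H, name=f'node_{k}')(v) for k, v in node_feats.items())
    e = sum(hk.Linear(H, name=f'edge_{k}')(v) for k, v in edge_feats.items())
    g = sum(hk.Linear(H, name=f'graph_{k}')(v) for k, v in graph_feats.items())
    return x, e, g
```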
|
Processor. The embeddings are fed into a processor P, a GNN that performs one step of computation. The processor transforms the input node, edge and graph embeddings into processed node embeddings, h_i^{(t)}. Additionally, the processor uses the processed node embeddings from the previous step, h_i^{(t-1)}, as inputs. Importantly, the same processor model can operate on graphs of any size. We leverage the message-passing neural network [35, MPNN], using the max aggregation and passing messages over a fully-connected graph, as our base model. The MPNN computes processed embeddings as follows:

z_i^{(t)} = x_i^{(t)} ‖ h_i^{(t-1)}        m_i^{(t)} = max_{1 ≤ j ≤ n} f_m(z_i^{(t)}, z_j^{(t)}, e_{ij}^{(t)}, g^{(t)})        h_i^{(t)} = f_r(z_i^{(t)}, m_i^{(t)})        (1)

starting from h^{(0)} = 0. Here ‖ denotes concatenation, f_m : R^{2h} × R^{2h} × R^h × R^h → R^h is the message function (for which we use a three-layer MLP with ReLU activations), and f_r : R^{2h} × R^h → R^h is the readout function (for which we use a linear layer with ReLU activation). The use of the max aggregator is well-motivated by prior work [5, 9], and we use the fully connected graph—letting the neighbours j range over all nodes (1 ≤ j ≤ n)—in order to allow the model to overcome situations where the input graph structure may be suboptimal. Layer normalisation [36] is applied to h_i^{(t)} before using them further. Further details on the MPNN processor may be found in Veličković et al. [5].
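A sketch of one such processor step in JAX/Haiku, following Equation (1); the exact MLP sizes and the placement of layer normalisation are assumptions based on the description above:

```python
import jax
import jax.numpy as jnp
import haiku as hk

H = 128

def mpnn_step(x, e, g, h_prev):
    """One MPNN step over a fully connected graph (Equation 1), as a sketch.

    x: [n, H] encoded node embeddings, e: [n, n, H] edge embeddings,
    g: [H] graph embedding, h_prev: [n, H] processed nodes from step t-1.
    """
    n = x.shape[0]
    z = jnp.concatenate([x, h_prev], axis=-1)                      # [n, 2H]

    # Message function f_m: a three-layer MLP with ReLU activations.
    f_m = hk.nets.MLP([H, H, H], activation=jax.nn.relu)
    z_i = jnp.broadcast_to(z[:, None, :], (n, n, 2 * H))           # receiving node i
    z_j = jnp.broadcast_to(z[None, :, :], (n, n, 2 * H))           # neighbour j
    g_b = jnp.broadcast_to(g[None, None, :], (n, n, H))
    msgs = f_m(jnp.concatenate([z_i, z_j, e, g_b], axis=-1))       # [n, n, H]

    # Max aggregation over all neighbours j (fully connected graph).
    m = jnp.max(msgs, axis=1)                                      # [n, H]

    # Readout f_r: a linear layer with ReLU activation, then layer norm.
    h = jax.nn.relu(hk.Linear(H)(jnp.concatenate([z, m], axis=-1)))
    return hk.LayerNorm(axis=-1, create_scale=True, create_offset=True)(h)
```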
|
Decoder. The processed embeddings are finally decoded with a task-based decoder g_τ, to predict the hints for the next step, and the outputs at the final step. Akin to the encoder, the task-based decoder relies mainly on a linear decoder for each hint and output, along with a mechanism to compute pairwise node similarities when appropriate. Specifically, the pointer type decoder computes a score, s_{ij}, for each pair of nodes, and then chooses the pointer of node i by taking either the argmax_j s_{ij} or softmax_j s_{ij} (depending on whether a hard or soft prediction is used).
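As an illustration, a pointer decoder might look as follows; the dot-product parameterisation of the pairwise score s_ij is an assumption, since the text above only specifies that a score is computed for each pair of nodes:

```python
import jax
import jax.numpy as jnp
import haiku as hk

H = 128

def pointer_decoder(h, hard=False):
    """Decode pointer predictions from processed node embeddings (a sketch).

    h: [n, H] processed node embeddings. Returns, for each node i, either a
    distribution over candidate pointees j (soft) or an index (hard).
    """
    # Pairwise scores s_ij from linear projections of the two endpoints.
    s = jnp.dot(hk.Linear(H)(h), hk.Linear(H)(h).T)   # [n, n]
    if hard:
        return jnp.argmax(s, axis=-1)                 # pointer of node i
    return jax.nn.softmax(s, axis=-1)                 # soft prediction
```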
|
Loss. The decoded hints and outputs are used to compute the loss during training, according to their type [5]. For each sample in a batch, the hint prediction losses are averaged across hints and time, and the output loss is averaged across outputs (most algorithms have a single output, though some have two outputs). The hint loss and output loss are added together. In addition, the hint predictions at each time step are fed back as inputs for the next step, except possibly at train time if teacher forcing is used (see Section 3.2.1).
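A minimal sketch of how the per-sample loss terms described above could be combined (shapes and names are illustrative):

```python
import jax.numpy as jnp

def total_loss(hint_losses, output_losses):
    """Combine losses for one sample, as described above (illustrative only).

    hint_losses: [num_hints, T] per-hint, per-time-step losses.
    output_losses: [num_outputs] per-output losses for the same sample.
    """
    hint_loss = jnp.mean(hint_losses)      # averaged across hints and time
    output_loss = jnp.mean(output_losses)  # averaged across outputs
    return hint_loss + output_loss         # the two terms are added together
```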
|
We train the model on samples with sizes n ≤ 16, and periodically evaluate it on in-distribution samples of size n = 16. Also periodically, we evaluate the model with the best in-distribution evaluation score so far on OOD samples of size n = 64. In what follows, we report only these OOD evaluation scores. Full details of the model, training and evaluation hyperparameters can be found in Appendix A.
|
3.2 Model improvements
|
|
|
As previously discussed, single-task improvements, especially in terms of learning stability, will empirically transfer well to multi-task algorithmic learning. We now describe, in a gradual manner, all the changes made to the model, which have led to an absolute improvement of over 20% on average across all 30 tasks in CLRS.
|
3.2.1 Dataset and training
|
|
|
Removing teacher forcing. At evaluation time, the model has no access to the step-by-step hints in the dataset, and has to rely on its own hint predictions. However, during training, it is sometimes advisable to stabilise the trajectories with teacher forcing [37]—providing the ground-truth hint values instead of the network's own predictions. In the prior model [5], ground-truth hints were provided during training with probability 0.5, as, without teacher forcing, losses tended to grow unbounded along a trajectory when scalar hints were present, destabilising the training. In this work we incorporate several significant stabilising changes (described in the following paragraphs), which allows us to remove teacher forcing altogether, aligning training with evaluation, and avoiding the network becoming overconfident in always expecting correct hint predictions. With teacher forcing, performance deteriorates significantly in sorting algorithms and Kruskal's algorithm. Naïve String Matcher, on the other hand, improves with teacher forcing (see Appendix A, Figs. 7-9).
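For illustration, teacher forcing amounts to a stochastic choice of which hints are fed to the next step; the sketch below is hypothetical, with the forcing probability set to 0.5 for the prior model and 0.0 for this work:

```python
import jax
import jax.numpy as jnp

def next_step_hints(rng, predicted_hints, ground_truth_hints, teacher_forcing_prob):
    """Select the hints fed into the next processor step (a sketch).

    With teacher_forcing_prob = 0.5 this mimics the prior model [5]; setting
    it to 0.0, as in this work, means the model always consumes its own hint
    predictions, both at training and at evaluation time.
    """
    use_ground_truth = jax.random.bernoulli(rng, teacher_forcing_prob)
    # Both hint pytrees share the same structure, so select per leaf.
    return jax.tree_util.tree_map(
        lambda pred, true: jnp.where(use_ground_truth, true, pred),
        predicted_hints, ground_truth_hints)
```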
|
Augmenting the training data. To prevent our model from over-fitting to the statistics of the fixed CLRS training dataset [5], we augmented the training data in three key ways, without breaking the intended size distribution shift. Firstly, we used the on-line samplers in CLRS to generate new training examples on the fly, rather than using a fixed dataset, which is easier to overfit to. Secondly, we trained on examples of mixed sizes, n ≤ 16, rather than only n = 16, which helps the model anticipate a diverse range of sizes, rather than overfitting to the specifics of size n = 16. Lastly, for graph algorithms, we varied the connectivity probability p of the input graphs (generated by the Erdős-Rényi model [38]); and for string matching algorithms, we varied the length of the pattern to be matched. Both of these serve to expose the model to different trajectory lengths; for example, in many graph algorithms, the number of steps the algorithm should run for is related to the graph's diameter, and varying the connection probability in the graph generation allows for varying the expected diameter. These changes considerably increase training data variability, compared to the original dataset in Veličković et al. [5]. We provide a more detailed step-by-step overview of the data generation process in Appendix A.
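A sketch of what such on-the-fly sampling of training-example parameters might look like; the concrete ranges and probability values below are illustrative assumptions, not the ones used in the paper:

```python
import numpy as np

rng = np.random.default_rng(0)

def sample_training_example_params():
    """Draw parameters for one freshly generated training example.

    The key point is that every quantity is resampled on the fly rather
    than read from a fixed dataset.
    """
    n = int(rng.integers(4, 17))                      # mixed sizes, n <= 16
    p = float(rng.choice([0.1, 0.2, 0.3, 0.5, 0.7]))  # ER connectivity (graph tasks)
    pattern_length = int(rng.integers(1, n + 1))      # pattern length (string tasks)
    return n, p, pattern_length
```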
|
Soft hint propagation. When predicted hints are fed back as inputs during training, gradients may or may not be allowed to flow through them. In previous work, only hints of the scalar type allowed gradients through, as all categoricals were post-processed from logits into the ground-truth format via argmax or thresholding before being fed back. Instead, in this work we use softmax for categorical, mask_one and pointer types, and the logistic sigmoid for mask types. Without these soft hints, performance in sorting algorithms degrades (similarly to the case of teacher forcing), as well as in Naïve String Matcher (Appendix A, Figs. 7-9).
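A minimal sketch of the soft post-processing described above, applied to predicted hint logits before they are fed back:

```python
import jax

def soft_hint(hint_type, logits):
    """Post-process predicted hint logits so gradients can flow through them.

    A sketch of the soft alternatives described above; previously, categorical
    hints were argmax-ed or thresholded into a hard format before reuse.
    """
    if hint_type in ('categorical', 'mask_one', 'pointer'):
        return jax.nn.softmax(logits, axis=-1)   # soft one-hot / pointer
    if hint_type == 'mask':
        return jax.nn.sigmoid(logits)            # soft binary mask
    return logits                                # scalar hints are used as-is
```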
|
Static hint elimination. Eleven algorithms in CLRS³ specify a fixed ordering of the nodes, common to every sample, via a node pointer hint that does not ever change along the trajectories. Prediction of this hint is trivial (identity function), but poses a potential problem for OOD generalisation, since the model can overfit to the fixed training values. We therefore turned this fixed hint into an input for these 11 algorithms, eliminating the need for explicitly predicting it.
|
Improving training stability with encoder initialisation and gradient clipping. The scalar hints have unbounded values, in principle, and are optimised using mean-squared error, hence their gradients can quickly grow with increasing prediction error. Further, the predicted scalar hints then get re-encoded at every step, which can rapidly amplify errors throughout the trajectory, leading to exploding signals (and consequently gradients), even before any training takes place.

To rectify this issue, we use the Xavier initialisation [45], effectively reducing the initial weights for scalar hints whose input dimensionality is just 1. However, we reverted to using the default LeCun initialisation [46] elsewhere. This combination of initialisations proved important for the initial learning stability of our model over long trajectories. Relatedly, in preliminary experiments, we saw drastic improvements in learning stability, as well as significant increases in validation performance, with gradient clipping [47], which we subsequently employed in all experiments.
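The following sketch shows one way to express these two changes with Haiku initialisers and Optax; the clipping norm and learning rate are illustrative values, not necessarily those used in the paper:

```python
import haiku as hk
import optax

# Xavier (Glorot) initialisation for the scalar-hint encoders, whose input
# dimensionality is just 1, and LeCun initialisation everywhere else.
XAVIER = hk.initializers.VarianceScaling(1.0, 'fan_avg', 'uniform')
LECUN = hk.initializers.VarianceScaling(1.0, 'fan_in', 'truncated_normal')

def make_encoder(is_scalar_hint):
    """Build a linear encoder (call inside an hk.transform-ed function)."""
    w_init = XAVIER if is_scalar_hint else LECUN
    return hk.Linear(128, w_init=w_init)

# Gradient clipping, employed in all experiments; 1.0 and 1e-3 are
# illustrative hyperparameter values.
optimiser = optax.chain(
    optax.clip_by_global_norm(1.0),
    optax.adam(1e-3),
)
```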
|
3.2.2 Encoders and decoders
|
|
|
Randomised position scalar. Across all algorithms in the dataset, there exists a position scalar input which uniquely indexes the nodes, with values linearly spaced between 0 and 1 along the node index. To avoid overfitting to these linearly spaced values during training, we replaced them with random values, uniformly sampled in [0, 1], sorted to match the initial order implied by the linearly spaced values. The benefit of this change is notable in algorithms where it would be easy to overfit to these positions, such as string matching. Namely, the model could learn to base all of its computations on the assumption that it will always be finding an m-character pattern inside an n-character string, even though at test time, m and n will increase fourfold.

³ Binary Search, Minimum, Max Subarray [39], Matrix Chain Order, LCS Length, Optimal BST [40], Activity Selector [41], Task Scheduling [42], Naïve String Matcher, Knuth-Morris-Pratt [43] and Jarvis' March [44].
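A minimal sketch of the randomised position scalar described above:

```python
import jax
import jax.numpy as jnp

def randomised_positions(rng, n):
    """Replace the linearly spaced position scalar with sorted random values.

    Returns an array of n values in [0, 1] that is still monotonically
    increasing along the node index, but no longer evenly spaced.
    """
    return jnp.sort(jax.random.uniform(rng, (n,)))

# E.g. instead of jnp.linspace(0, 1, 16), feed:
# randomised_positions(jax.random.PRNGKey(0), 16)
```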
|
Permutation decoders and the Sinkhorn operator. Sorting algorithms (Insertion Sort, Bubble Sort, Heapsort [48] and Quicksort [49]) always output a permutation of the input nodes. In the CLRS benchmark, this permutation is encoded as a pointer where each node points to its predecessor in the sorted order (the first node points to itself); this is represented as an n × n matrix P where each row is a one-hot vector, such that element (i, j) is 1 if node i points to node j. As with all types of pointers, such permutation pointers can be predicted using a row-wise softmax on unconstrained decoder outputs (logits), trained with cross entropy (as in Veličković et al. [5]). However, this does not explicitly take advantage of the fact that the pointers encode a permutation, which the model has to learn instead. Our early experiments showed that the model was often failing to predict valid permutations OOD.
|
Accordingly, we enforce a permutation inductive bias in the output decoder of sorting algorithms, as follows. First, we modify the output representation by rewiring the first node to point to the last one, turning P into a permutation matrix, i.e., a matrix whose rows and columns are one-hot vectors. We also augment the representation with a one-hot vector of size n that specifies the first node, so we do not lose this information; this vector is treated like a regular mask_one feature. Second, we predict the permutation matrix P from unconstrained decoder outputs Y by replacing the usual row-wise softmax with the Sinkhorn operator S [32, 50–53]. S projects an arbitrary square matrix Y into a doubly stochastic matrix S(Y) (a non-negative matrix whose rows and columns sum to 1), by exponentiating and repeatedly normalising rows and columns so they sum to 1. Specifically, S is defined by:

S^0(Y) = exp(Y),        S^l(Y) = T_c(T_r(S^{l-1}(Y))),        S(Y) = lim_{l→∞} S^l(Y),        (2)

where exp acts element-wise, and T_r and T_c denote row and column normalisation, respectively. Although the Sinkhorn operator produces a doubly stochastic matrix rather than a permutation matrix, we can obtain a permutation matrix by introducing a temperature parameter, τ > 0, and taking P = lim_{τ→0+} S(Y/τ); as long as there are no ties in the elements of Y, P is guaranteed to be a permutation matrix [52, Theorem 1].
|
In practice, we compute the Sinkhorn operator using a fixed number of iterations, l_max. We use a smaller number of iterations, l_max = 10, for training, to limit vanishing and exploding gradients, and l_max = 60 for evaluation. A fixed temperature τ = 0.1 was experimentally found to give a good balance between speed of convergence and tie-breaking. We also encode the fact that no node points to itself, that is, that all diagonal elements of P should be 0, by setting the diagonal elements of Y to −∞. To avoid ties, we follow Mena et al. [53], injecting Gumbel noise into the elements of Y prior to applying the Sinkhorn operator, during training only. Finally, we transform the predicted matrix P, and the mask_one pointing to the first element, into the original pointer representation used by CLRS.
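A log-space sketch of the Sinkhorn-based permutation decoding described above (Equation 2 plus temperature, diagonal masking and Gumbel noise); numerical details such as the use of a large negative constant in place of −∞ are implementation assumptions:

```python
import jax
import jax.numpy as jnp
from jax.scipy.special import logsumexp

def sinkhorn(logits, temperature=0.1, steps=10, rng=None):
    """Sinkhorn operator over decoder outputs Y (Equation 2), in log space.

    logits: [n, n] unconstrained decoder outputs. Returns an (approximately)
    doubly stochastic matrix. steps=10 is used for training and steps=60 for
    evaluation; an rng key is passed during training only, for Gumbel noise.
    """
    n = logits.shape[0]
    # No node points to itself: a large negative constant stands in for -inf
    # on the diagonal of Y.
    y = jnp.where(jnp.eye(n, dtype=bool), -1e9, logits)
    if rng is not None:
        y = y + jax.random.gumbel(rng, y.shape)  # tie-breaking (training only)
    log_p = y / temperature
    for _ in range(steps):
        log_p = log_p - logsumexp(log_p, axis=1, keepdims=True)  # rows, T_r
        log_p = log_p - logsumexp(log_p, axis=0, keepdims=True)  # columns, T_c
    return jnp.exp(log_p)
```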
|
3.2.3 Processor networks
|
|
|
Gating mechanisms. Many algorithms only require updating a few nodes at each time step, keeping the rest unchanged. However, the MPNN we use (Equation 1) is biased towards the opposite: it updates all hidden states in each step. Although it is theoretically possible for the network to keep the states unchanged, learning to do so is not easy. With this in mind, and motivated by its effectiveness in NDRs [54], we augment the network with an update gate, biased to be closed by default. We found that the gate stabilises learning on many of the tasks, and significantly increases the mean performance over all tasks in single-task training. Surprisingly, however, we did not find gating to be advantageous in the multi-task case.

To add gating to the MPNN model we produce a per-node gating vector from the same inputs that process the embeddings in Equation 1:

g_i^{(t)} = f_g(z_i^{(t)}, m_i^{(t)})        (3)
|
where f_g : R^{2h} × R^h → R^h is the gating function, for which we use a two-layer MLP, with ReLU activation for the hidden layer and logistic sigmoid activation for the output. Importantly, the final-layer bias of f_g is initialised to a value of −3, which biases the network towards not updating its representations unless necessary.
|
|
|
[Figure 2: The OOD performance in single-task experiments before and after the improvements presented in this paper, sorted in descending order of current performance. Error bars represent standard error of the mean across seeds (3 seeds for previous SOTA experiments, 10 seeds for current). The previous SOTA values are the best of MPNN, PGN and Memnet models (see Table 2). The bar chart plots the average score [%] per algorithm for "Our model" versus "Previous SOTA [5]" across all 30 CLRS algorithms.]
|
The processed gated embeddings, ĥ_i^{(t)}, are computed as follows:

ĥ_i^{(t)} = g_i^{(t)} ⊙ h_i^{(t)} + (1 − g_i^{(t)}) ⊙ h_i^{(t−1)}        (4)

and are used instead of h_i^{(t)} in the subsequent steps, replacing z_i^{(t)} in Eq. 1 by z_i^{(t)} = x_i^{(t)} ‖ ĥ_i^{(t−1)}.
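A sketch of the gated update (Equations 3-4), intended to be called inside an hk.transform-ed processor; the hidden size of f_g is an assumption:

```python
import jax
import jax.numpy as jnp
import haiku as hk

H = 128

def gated_update(z, m, h, h_prev):
    """Per-node update gate (Equations 3-4), as a sketch.

    z: [n, 2H] concatenated inputs, m: [n, H] aggregated messages,
    h: [n, H] candidate processed embeddings from Equation 1,
    h_prev: [n, H] processed embeddings from the previous step.
    """
    # f_g: two-layer MLP, ReLU hidden layer, sigmoid output. The final-layer
    # bias is initialised to -3 so that the gate is closed by default.
    f_g = hk.Sequential([
        hk.Linear(H), jax.nn.relu,
        hk.Linear(H, b_init=hk.initializers.Constant(-3.0)), jax.nn.sigmoid,
    ])
    gate = f_g(jnp.concatenate([z, m], axis=-1))   # g_i^(t), in (0, 1)
    # Gated embeddings replace h_i^(t) in all subsequent computation.
    return gate * h + (1.0 - gate) * h_prev
```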
|
|
|
Triplet reasoning. Several algorithms within CLRS-30 explicitly require edge-based reasoning—where edges store values, and update them based on other edges' values. An example of this is the Floyd-Warshall algorithm [55], which computes all-pairs shortest paths in a weighted graph. The update rule for d_ij, its estimate for the best distance from node i to j, is d_ij = min_k (d_ik + d_kj), which roughly says "the best way to get from i to j is to find the optimal mid-point k, travel from i to k, then from k to j". Similar rules are pervasive across many CLRS-30 algorithms, especially in dynamic programming. Even though there are no node representations in the above update, all our processors are centered on passing messages between node representations h_i.

To rectify this situation, we augment our processor to perform message passing towards edges. Referring again to the update for d_ij, we note that the edge representations are updated by choosing an intermediate node, then aggregating over all possible choices. Accordingly, and as previously observed by Dudzik and Veličković [31], we introduce triplet reasoning: first, computing representations over triplets of nodes, then reducing over one node to obtain edge latents:

t_ijk = ψ_t(h_i, h_j, h_k, e_ij, e_ik, e_kj, g)        h_ij = φ_t(max_k t_ijk)        (5)

Here, ψ_t is a triplet message function, mapping all relevant representations to a single vector for each triplet of nodes, and φ_t is an edge readout function, which transforms the aggregated triplets for each edge for later use. In accordance with prior findings on the CLRS benchmark [5], we use the max aggregation to obtain edge representations. The computed h_ij vectors can then be used in any edge-based reasoning task, and empirically they are indeed significantly beneficial, even in tasks where we did not initially anticipate such benefits. One example is Kruskal's minimum spanning tree algorithm [56], where we presume that access to triplet reasoning allowed the model to more easily sort the edges by weight, as it selects how to augment the spanning forest at each step.
|
In order to keep the footprint of triplet embeddings as lightweight as possible, we compute only 8-dimensional features in ψ_t. φ_t then upscales the aggregated edge features back to 128 dimensions, to make them compatible with the rest of the architecture. Our initial experimentation demonstrated that the output dimensionality of ψ_t did not significantly affect downstream performance. Note that computing triplet representations has been a useful approach in general GNN design [57]—however, it has predominantly been studied in the context of GNNs over constant input features. Our study is among the first to verify their utility over reasoning tasks with well-specified initial features.
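A sketch of triplet reasoning (Equation 5) with 8-dimensional triplet features upscaled back to 128 dimensions; combining the seven inputs by summing separate linear projections is an assumption made here to avoid materialising a huge concatenated tensor:

```python
import jax.numpy as jnp
import haiku as hk

H, T = 128, 8  # edge latent size and (small) triplet feature size

def triplet_reasoning(h, e, g):
    """Triplet reasoning (Equation 5), as a sketch.

    h: [n, H] node embeddings, e: [n, n, H] edge embeddings, g: [H] graph
    embedding. Returns [n, n, H] edge latents h_ij.
    """
    # psi_t: each of the seven inputs gets its own small linear projection,
    # and the projections are broadcast-added into a [n, n, n, T] tensor
    # indexed by (i, j, k).
    t = (hk.Linear(T)(h)[:, None, None, :]                           # h_i
         + hk.Linear(T)(h)[None, :, None, :]                         # h_j
         + hk.Linear(T)(h)[None, None, :, :]                         # h_k
         + hk.Linear(T)(e)[:, :, None, :]                            # e_ij
         + hk.Linear(T)(e)[:, None, :, :]                            # e_ik
         + jnp.transpose(hk.Linear(T)(e), (1, 0, 2))[None, :, :, :]  # e_kj
         + hk.Linear(T)(g)[None, None, None, :])                     # g
    # Reduce over the intermediate node k with max, then let phi_t upscale
    # the aggregated edge features back to H dimensions.
    edge = jnp.max(t, axis=2)        # [n, n, T]
    return hk.Linear(H)(edge)        # [n, n, H]
```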
|
3.3 Results
|
|
|
By incorporating the changes described in the previous sections we arrived at a single model type, with a single set of hyper-parameters, that was trained to reach new state-of-the-art performance on CLRS-30 [5].
|
|
|
Table 1: Single-task OOD micro-F1 score of previous SOTA Memnet, MPNN and PGN [5] and our best model Triplet-GMPNN with all our improvements, after 10,000 training steps.

| Alg. Type    | Memnet [5]     | MPNN [5]       | PGN [5]        | Triplet-GMPNN (ours) |
|--------------|----------------|----------------|----------------|----------------------|
| Div. & C.    | 13.05% ± 0.14  | 20.30% ± 0.85  | 65.23% ± 4.44  | 76.36% ± 1.34        |
| DP           | 67.94% ± 8.20  | 65.10% ± 6.44  | 70.58% ± 6.48  | 81.99% ± 4.98        |
| Geometry     | 45.14% ± 11.95 | 73.11% ± 17.19 | 61.19% ± 7.01  | 94.09% ± 2.30        |
| Graphs       | 24.12% ± 5.30  | 62.79% ± 8.75  | 60.25% ± 8.42  | 81.41% ± 6.21        |
| Greedy       | 53.42% ± 20.82 | 82.39% ± 3.01  | 75.84% ± 6.59  | 91.21% ± 2.95        |
| Search       | 34.35% ± 21.67 | 41.20% ± 19.87 | 56.11% ± 21.56 | 58.61% ± 24.34       |
| Sorting      | 71.53% ± 1.41  | 11.83% ± 2.78  | 15.45% ± 8.46  | 60.37% ± 12.16       |
| Strings      | 1.51% ± 0.46   | 3.21% ± 0.94   | 2.04% ± 0.20   | 49.09% ± 23.49       |
| Overall avg. | 38.88%         | 44.99%         | 50.84%         | 74.14%               |
| > 90%        | 0/30           | 6/30           | 3/30           | 11/30                |
| > 80%        | 3/30           | 9/30           | 7/30           | 17/30                |
| > 60%        | 10/30          | 14/30          | 15/30          | 24/30                |
|
|
|
Tables 1 and 2 show the micro-F1 scores of our model, which we refer to as Triplet-GMPNN (an MPNN with gating and triplet edge processing), over the original CLRS-30 test set (computed identically to Veličković et al. [5], but with 10 repetitions instead of 3). Our baselines include the Memnet [58], MPNN [35] and PGN [59] models, taken directly from Veličković et al. [5]. Figure 2 displays the comparison between the improved model and the best model from Veličković et al. [5]. Our improvements lead to an overall average performance that is more than 20% higher (in absolute terms) than the next best model (see Table 1), and to a significant performance improvement in all but one algorithm family, compared to every other model. Further, our stabilising changes (such as gradient clipping) have empirically reduced the scale of our model's gradient updates across the 30 tasks, preparing us better for the numerical issues of the multi-task regime. Finally, we note that, though we do not show it in Tables 1 & 2, applying the same improvements to the PGN processor leads to an increase in overall performance from 50.84% (Table 1) to 69.31%.
|
There are two notable examples of algorithm families with significant OOD performance improvements. The first is geometric algorithms (Segments Intersect, Graham Scan [60] and Jarvis' March), now solved at approximately 94% OOD, compared to the previous best of about 73%; the second is string algorithms (Knuth-Morris-Pratt and Naïve String Matcher), for which our model now exceeds 49%, compared to the previous best of approximately 3%.
|
The significant overall performance boost is reflected in the increased number of algorithms we can now solve at over 60%, 80% and 90% OOD performance, compared to the previous SOTA [5]. Specifically, we now exceed 60% accuracy in 24 algorithms (15 algorithms previously), 80% in 17 algorithms (9 previously) and 90% in 11 algorithms (6 previously).
|
|