"""JAX implementation of baseline processor networks.""" |
|
|
|
import abc |
|
from typing import Any, Callable, List, Optional, Tuple |
|
|
|
import chex |
|
import haiku as hk |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
|
|
|
|
_Array = chex.Array |
|
_Fn = Callable[..., Any] |
|
BIG_NUMBER = 1e6 |
|
PROCESSOR_TAG = 'clrs_processor' |
|
|
|
|
|
class Processor(hk.Module): |
|
"""Processor abstract base class.""" |
|
|
|
def __init__(self, name: str): |
|
if not name.endswith(PROCESSOR_TAG): |
|
name = name + '_' + PROCESSOR_TAG |
|
super().__init__(name=name) |
|
|
|
@abc.abstractmethod |
|
def __call__( |
|
self, |
|
node_fts: _Array, |
|
edge_fts: _Array, |
|
graph_fts: _Array, |
|
adj_mat: _Array, |
|
hidden: _Array, |
|
**kwargs, |
|
) -> Tuple[_Array, Optional[_Array]]: |
|
"""Processor inference step. |
|
|
|
Args: |
|
node_fts: Node features. |
|
edge_fts: Edge features. |
|
graph_fts: Graph features. |
|
adj_mat: Graph adjacency matrix. |
|
hidden: Hidden features. |
|
**kwargs: Extra kwargs. |
|
|
|
Returns: |
|
Output of processor inference step as a 2-tuple of (node, edge) |
|
embeddings. The edge embeddings can be None. |
|
""" |
|
pass |
|
|
|
@property |
|
def inf_bias(self): |
|
return False |
|
|
|
@property |
|
def inf_bias_edge(self): |
|
return False |
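
# A minimal sketch of the Processor contract (illustration only, not part of
# the library): a concrete subclass consumes node/edge/graph features plus the
# previous hidden state and returns new node embeddings and optional edge
# embeddings. `_MeanAggrProcessor` is a hypothetical toy example.
class _MeanAggrProcessor(Processor):
  """Toy processor: mean-aggregates linearly transformed neighbour states."""

  def __call__(self, node_fts, edge_fts, graph_fts, adj_mat, hidden,
               **unused_kwargs):
    del edge_fts, graph_fts  # Unused by this toy example.
    z = jnp.concatenate([node_fts, hidden], axis=-1)
    msgs = jnp.matmul(adj_mat, hk.Linear(z.shape[-1])(z))  # Sum over senders.
    deg = jnp.maximum(jnp.sum(adj_mat, axis=-1, keepdims=True), 1.0)
    return msgs / deg, None  # No edge embeddings.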
|
|
|
|
|
class GAT(Processor): |
|
"""Graph Attention Network (Velickovic et al., ICLR 2018).""" |
|
|
|
def __init__( |
|
self, |
|
out_size: int, |
|
nb_heads: int, |
|
activation: Optional[_Fn] = jax.nn.relu, |
|
residual: bool = True, |
|
use_ln: bool = False, |
|
name: str = 'gat_aggr', |
|
): |
|
super().__init__(name=name) |
|
self.out_size = out_size |
|
self.nb_heads = nb_heads |
|
if out_size % nb_heads != 0: |
|
raise ValueError('The number of attention heads must divide the width!') |
|
self.head_size = out_size // nb_heads |
|
self.activation = activation |
|
self.residual = residual |
|
self.use_ln = use_ln |
|
|
|
def __call__( |
|
self, |
|
node_fts: _Array, |
|
edge_fts: _Array, |
|
graph_fts: _Array, |
|
adj_mat: _Array, |
|
hidden: _Array, |
|
**unused_kwargs, |
|
  ) -> Tuple[_Array, Optional[_Array]]:
|
"""GAT inference step.""" |
|
|
|
b, n, _ = node_fts.shape |
|
assert edge_fts.shape[:-1] == (b, n, n) |
|
assert graph_fts.shape[:-1] == (b,) |
|
assert adj_mat.shape == (b, n, n) |
|
|
|
z = jnp.concatenate([node_fts, hidden], axis=-1) |
|
m = hk.Linear(self.out_size) |
|
skip = hk.Linear(self.out_size) |
|
|
|
bias_mat = (adj_mat - 1.0) * 1e9 |
|
bias_mat = jnp.tile(bias_mat[..., None], |
|
(1, 1, 1, self.nb_heads)) |
|
bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2)) |
|
|
|
a_1 = hk.Linear(self.nb_heads) |
|
a_2 = hk.Linear(self.nb_heads) |
|
a_e = hk.Linear(self.nb_heads) |
|
a_g = hk.Linear(self.nb_heads) |
|
|
|
values = m(z) |
|
values = jnp.reshape( |
|
values, |
|
values.shape[:-1] + (self.nb_heads, self.head_size)) |
|
values = jnp.transpose(values, (0, 2, 1, 3)) |
|
|
|
att_1 = jnp.expand_dims(a_1(z), axis=-1) |
|
att_2 = jnp.expand_dims(a_2(z), axis=-1) |
|
att_e = a_e(edge_fts) |
|
att_g = jnp.expand_dims(a_g(graph_fts), axis=-1) |
|
|
|
    logits = (

        jnp.transpose(att_1, (0, 2, 1, 3)) +  # + [B, H, N, 1]

        jnp.transpose(att_2, (0, 2, 3, 1)) +  # + [B, H, 1, N]

        jnp.transpose(att_e, (0, 3, 1, 2)) +  # + [B, H, N, N]

        jnp.expand_dims(att_g, axis=-1)       # + [B, H, 1, 1]

    )                                         # = [B, H, N, N]
|
coefs = jax.nn.softmax(jax.nn.leaky_relu(logits) + bias_mat, axis=-1) |
|
ret = jnp.matmul(coefs, values) |
|
ret = jnp.transpose(ret, (0, 2, 1, 3)) |
|
ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,)) |
|
|
|
if self.residual: |
|
ret += skip(z) |
|
|
|
if self.activation is not None: |
|
ret = self.activation(ret) |
|
|
|
if self.use_ln: |
|
ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True) |
|
ret = ln(ret) |
|
|
|
return ret, None |
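
# A hedged usage sketch (illustration only): processors are Haiku modules, so
# they are instantiated and called inside `hk.transform`. All sizes below are
# assumptions chosen purely for illustration.
def _example_gat_usage():
  def forward(node_fts, edge_fts, graph_fts, adj_mat, hidden):
    return GAT(out_size=8, nb_heads=2)(
        node_fts, edge_fts, graph_fts, adj_mat, hidden)

  init, apply = hk.transform(forward)
  b, n, h = 2, 4, 8
  args = (jnp.zeros((b, n, h)), jnp.zeros((b, n, n, h)), jnp.zeros((b, h)),
          jnp.ones((b, n, n)), jnp.zeros((b, n, h)))
  params = init(jax.random.PRNGKey(0), *args)
  node_emb, edge_emb = apply(params, None, *args)  # edge_emb is None for GAT.
  assert node_emb.shape == (b, n, 8)
  return node_emb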
|
|
|
|
|
class GATFull(GAT): |
|
"""Graph Attention Network with full adjacency matrix.""" |
|
|
|
def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array, |
|
               adj_mat: _Array, hidden: _Array,
               **unused_kwargs) -> Tuple[_Array, Optional[_Array]]:
|
adj_mat = jnp.ones_like(adj_mat) |
|
return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden) |
|
|
|
|
|
class GATv2(Processor): |
|
"""Graph Attention Network v2 (Brody et al., ICLR 2022).""" |
|
|
|
def __init__( |
|
self, |
|
out_size: int, |
|
nb_heads: int, |
|
mid_size: Optional[int] = None, |
|
activation: Optional[_Fn] = jax.nn.relu, |
|
residual: bool = True, |
|
use_ln: bool = False, |
|
name: str = 'gatv2_aggr', |
|
): |
|
super().__init__(name=name) |
|
if mid_size is None: |
|
self.mid_size = out_size |
|
else: |
|
self.mid_size = mid_size |
|
self.out_size = out_size |
|
self.nb_heads = nb_heads |
|
if out_size % nb_heads != 0: |
|
raise ValueError('The number of attention heads must divide the width!') |
|
self.head_size = out_size // nb_heads |
|
if self.mid_size % nb_heads != 0: |
|
      raise ValueError(
          'The number of attention heads must divide the message size!')
|
self.mid_head_size = self.mid_size // nb_heads |
|
self.activation = activation |
|
self.residual = residual |
|
self.use_ln = use_ln |
|
|
|
def __call__( |
|
self, |
|
node_fts: _Array, |
|
edge_fts: _Array, |
|
graph_fts: _Array, |
|
adj_mat: _Array, |
|
hidden: _Array, |
|
**unused_kwargs, |
|
  ) -> Tuple[_Array, Optional[_Array]]:
|
"""GATv2 inference step.""" |
|
|
|
b, n, _ = node_fts.shape |
|
assert edge_fts.shape[:-1] == (b, n, n) |
|
assert graph_fts.shape[:-1] == (b,) |
|
assert adj_mat.shape == (b, n, n) |
|
|
|
z = jnp.concatenate([node_fts, hidden], axis=-1) |
|
m = hk.Linear(self.out_size) |
|
skip = hk.Linear(self.out_size) |
|
|
|
bias_mat = (adj_mat - 1.0) * 1e9 |
|
bias_mat = jnp.tile(bias_mat[..., None], |
|
(1, 1, 1, self.nb_heads)) |
|
bias_mat = jnp.transpose(bias_mat, (0, 3, 1, 2)) |
|
|
|
w_1 = hk.Linear(self.mid_size) |
|
w_2 = hk.Linear(self.mid_size) |
|
w_e = hk.Linear(self.mid_size) |
|
w_g = hk.Linear(self.mid_size) |
|
|
|
a_heads = [] |
|
for _ in range(self.nb_heads): |
|
a_heads.append(hk.Linear(1)) |
|
|
|
values = m(z) |
|
values = jnp.reshape( |
|
values, |
|
values.shape[:-1] + (self.nb_heads, self.head_size)) |
|
values = jnp.transpose(values, (0, 2, 1, 3)) |
|
|
|
pre_att_1 = w_1(z) |
|
pre_att_2 = w_2(z) |
|
pre_att_e = w_e(edge_fts) |
|
pre_att_g = w_g(graph_fts) |
|
|
|
pre_att = ( |
|
jnp.expand_dims(pre_att_1, axis=1) + |
|
jnp.expand_dims(pre_att_2, axis=2) + |
|
pre_att_e + |
|
jnp.expand_dims(pre_att_g, axis=(1, 2)) |
|
) |
|
|
|
    pre_att = jnp.reshape(

        pre_att,

        pre_att.shape[:-1] + (self.nb_heads, self.mid_head_size)

    )  # [B, N, N, H, F]

    pre_att = jnp.transpose(pre_att, (0, 3, 1, 2, 4))  # [B, H, N, N, F]
|
|
|
|
|
|
|
    # Per-head attention is computed in a Python loop for readability,
    # assuming `nb_heads` stays small.
    logit_heads = []

    for head in range(self.nb_heads):

      logit_heads.append(

          jnp.squeeze(

              a_heads[head](jax.nn.leaky_relu(pre_att[:, head])),

              axis=-1)

      )

    logits = jnp.stack(logit_heads, axis=1)  # [B, H, N, N]
|
|
|
coefs = jax.nn.softmax(logits + bias_mat, axis=-1) |
|
ret = jnp.matmul(coefs, values) |
|
ret = jnp.transpose(ret, (0, 2, 1, 3)) |
|
ret = jnp.reshape(ret, ret.shape[:-2] + (self.out_size,)) |
|
|
|
if self.residual: |
|
ret += skip(z) |
|
|
|
if self.activation is not None: |
|
ret = self.activation(ret) |
|
|
|
if self.use_ln: |
|
ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True) |
|
ret = ln(ret) |
|
|
|
return ret, None |
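
# The key difference from GAT above is the order of operations: GAT applies
# LeakyReLU to already-scalar per-head logits, whereas GATv2 applies the
# learned per-head vector `a_heads[head]` after the LeakyReLU nonlinearity,
# which Brody et al. (ICLR 2022) show yields more expressive attention.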
|
|
|
|
|
class GATv2Full(GATv2): |
|
"""Graph Attention Network v2 with full adjacency matrix.""" |
|
|
|
def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array, |
|
               adj_mat: _Array, hidden: _Array,
               **unused_kwargs) -> Tuple[_Array, Optional[_Array]]:
|
adj_mat = jnp.ones_like(adj_mat) |
|
return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden) |
|
|
|
|
|
def get_triplet_msgs(z, edge_fts, graph_fts, nb_triplet_fts): |
|
"""Triplet messages, as done by Dudzik and Velickovic (2022).""" |
|
t_1 = hk.Linear(nb_triplet_fts) |
|
t_2 = hk.Linear(nb_triplet_fts) |
|
t_3 = hk.Linear(nb_triplet_fts) |
|
t_e_1 = hk.Linear(nb_triplet_fts) |
|
t_e_2 = hk.Linear(nb_triplet_fts) |
|
t_e_3 = hk.Linear(nb_triplet_fts) |
|
t_g = hk.Linear(nb_triplet_fts) |
|
|
|
tri_1 = t_1(z) |
|
tri_2 = t_2(z) |
|
tri_3 = t_3(z) |
|
tri_e_1 = t_e_1(edge_fts) |
|
tri_e_2 = t_e_2(edge_fts) |
|
tri_e_3 = t_e_3(edge_fts) |
|
tri_g = t_g(graph_fts) |
|
|
|
  return (

      jnp.expand_dims(tri_1, axis=(2, 3)) +     #   (B, N, 1, 1, H)

      jnp.expand_dims(tri_2, axis=(1, 3)) +     # + (B, 1, N, 1, H)

      jnp.expand_dims(tri_3, axis=(1, 2)) +     # + (B, 1, 1, N, H)

      jnp.expand_dims(tri_e_1, axis=3) +        # + (B, N, N, 1, H)

      jnp.expand_dims(tri_e_2, axis=2) +        # + (B, N, 1, N, H)

      jnp.expand_dims(tri_e_3, axis=1) +        # + (B, 1, N, N, H)

      jnp.expand_dims(tri_g, axis=(1, 2, 3))    # + (B, 1, 1, 1, H)

  )                                             # = (B, N, N, N, H)
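
# Downstream use (as in PGN below): the (B, N, N, N, H) triplet tensor is
# max-reduced over axis 1 and linearly projected, giving edge-level messages
# of shape (B, N, N, out_size); see Dudzik and Velickovic (2022).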
|
|
|
|
|
class PGN(Processor): |
|
"""Pointer Graph Networks (Veličković et al., NeurIPS 2020).""" |
|
|
|
def __init__( |
|
self, |
|
out_size: int, |
|
mid_size: Optional[int] = None, |
|
mid_act: Optional[_Fn] = None, |
|
activation: Optional[_Fn] = jax.nn.relu, |
|
reduction: _Fn = jnp.max, |
|
msgs_mlp_sizes: Optional[List[int]] = None, |
|
use_ln: bool = False, |
|
use_triplets: bool = False, |
|
nb_triplet_fts: int = 8, |
|
gated: bool = False, |
|
name: str = 'mpnn_aggr', |
|
): |
|
super().__init__(name=name) |
|
if mid_size is None: |
|
self.mid_size = out_size |
|
else: |
|
self.mid_size = mid_size |
|
self.out_size = out_size |
|
self.mid_act = mid_act |
|
self.activation = activation |
|
self.reduction = reduction |
|
self._msgs_mlp_sizes = msgs_mlp_sizes |
|
self.use_ln = use_ln |
|
self.use_triplets = use_triplets |
|
self.nb_triplet_fts = nb_triplet_fts |
|
self.gated = gated |
|
|
|
def __call__( |
|
self, |
|
node_fts: _Array, |
|
edge_fts: _Array, |
|
graph_fts: _Array, |
|
adj_mat: _Array, |
|
hidden: _Array, |
|
**unused_kwargs, |
|
  ) -> Tuple[_Array, Optional[_Array]]:
|
"""MPNN inference step.""" |
|
|
|
b, n, _ = node_fts.shape |
|
assert edge_fts.shape[:-1] == (b, n, n) |
|
assert graph_fts.shape[:-1] == (b,) |
|
assert adj_mat.shape == (b, n, n) |
|
|
|
z = jnp.concatenate([node_fts, hidden], axis=-1) |
|
m_1 = hk.Linear(self.mid_size) |
|
m_2 = hk.Linear(self.mid_size) |
|
m_e = hk.Linear(self.mid_size) |
|
m_g = hk.Linear(self.mid_size) |
|
|
|
o1 = hk.Linear(self.out_size) |
|
o2 = hk.Linear(self.out_size) |
|
|
|
msg_1 = m_1(z) |
|
msg_2 = m_2(z) |
|
msg_e = m_e(edge_fts) |
|
msg_g = m_g(graph_fts) |
|
|
|
tri_msgs = None |
|
|
|
if self.use_triplets: |
|
|
|
triplets = get_triplet_msgs(z, edge_fts, graph_fts, self.nb_triplet_fts) |
|
|
|
o3 = hk.Linear(self.out_size) |
|
      tri_msgs = o3(jnp.max(triplets, axis=1))  # (B, N, N, H)
|
|
|
if self.activation is not None: |
|
tri_msgs = self.activation(tri_msgs) |
|
|
|
msgs = ( |
|
jnp.expand_dims(msg_1, axis=1) + jnp.expand_dims(msg_2, axis=2) + |
|
msg_e + jnp.expand_dims(msg_g, axis=(1, 2))) |
|
|
|
if self._msgs_mlp_sizes is not None: |
|
msgs = hk.nets.MLP(self._msgs_mlp_sizes)(jax.nn.relu(msgs)) |
|
|
|
if self.mid_act is not None: |
|
msgs = self.mid_act(msgs) |
|
|
|
if self.reduction == jnp.mean: |
|
msgs = jnp.sum(msgs * jnp.expand_dims(adj_mat, -1), axis=1) |
|
msgs = msgs / jnp.sum(adj_mat, axis=-1, keepdims=True) |
|
elif self.reduction == jnp.max: |
|
maxarg = jnp.where(jnp.expand_dims(adj_mat, -1), |
|
msgs, |
|
-BIG_NUMBER) |
|
msgs = jnp.max(maxarg, axis=1) |
|
else: |
|
msgs = self.reduction(msgs * jnp.expand_dims(adj_mat, -1), axis=1) |
|
|
|
h_1 = o1(z) |
|
h_2 = o2(msgs) |
|
|
|
ret = h_1 + h_2 |
|
|
|
if self.activation is not None: |
|
ret = self.activation(ret) |
|
|
|
if self.use_ln: |
|
ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True) |
|
ret = ln(ret) |
|
|
|
    if self.gated:

      gate1 = hk.Linear(self.out_size)

      gate2 = hk.Linear(self.out_size)

      # Bias init of -3 gives an initial gate of sigmoid(-3) ~= 0.05, so the
      # update starts close to keeping the previous hidden state.
      gate3 = hk.Linear(self.out_size, b_init=hk.initializers.Constant(-3))

      gate = jax.nn.sigmoid(gate3(jax.nn.relu(gate1(z) + gate2(msgs))))

      ret = ret * gate + hidden * (1 - gate)
|
|
|
return ret, tri_msgs |
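
# A hedged usage sketch (illustration only): with triplets enabled, a PGN step
# returns node embeddings plus triplet-derived edge messages. Sizes below are
# assumptions chosen purely for illustration.
def _example_pgn_usage():
  def forward(node_fts, edge_fts, graph_fts, adj_mat, hidden):
    return PGN(out_size=16, use_triplets=True, nb_triplet_fts=8, gated=True)(
        node_fts, edge_fts, graph_fts, adj_mat, hidden)

  init, apply = hk.transform(forward)
  b, n, h = 2, 4, 16
  args = (jnp.zeros((b, n, h)), jnp.zeros((b, n, n, h)), jnp.zeros((b, h)),
          jnp.ones((b, n, n)), jnp.zeros((b, n, h)))
  params = init(jax.random.PRNGKey(0), *args)
  node_emb, tri_msgs = apply(params, None, *args)
  assert node_emb.shape == (b, n, 16) and tri_msgs.shape == (b, n, n, 16)
  return node_emb, tri_msgs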
|
|
|
|
|
class DeepSets(PGN): |
|
"""Deep Sets (Zaheer et al., NeurIPS 2017).""" |
|
|
|
def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array, |
|
               adj_mat: _Array, hidden: _Array,
               **unused_kwargs) -> Tuple[_Array, Optional[_Array]]:
|
assert adj_mat.ndim == 3 |
|
adj_mat = jnp.ones_like(adj_mat) * jnp.eye(adj_mat.shape[-1]) |
|
return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden) |
|
|
|
|
|
class MPNN(PGN): |
|
"""Message-Passing Neural Network (Gilmer et al., ICML 2017).""" |
|
|
|
def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array, |
|
               adj_mat: _Array, hidden: _Array,
               **unused_kwargs) -> Tuple[_Array, Optional[_Array]]:
|
adj_mat = jnp.ones_like(adj_mat) |
|
return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden) |
|
|
|
|
|
class PGNMask(PGN): |
|
"""Masked Pointer Graph Networks (Veličković et al., NeurIPS 2020).""" |
|
|
|
@property |
|
def inf_bias(self): |
|
return True |
|
|
|
@property |
|
def inf_bias_edge(self): |
|
return True |
|
|
|
|
|
class MemNetMasked(Processor): |
|
"""Implementation of End-to-End Memory Networks. |
|
|
|
Inspired by the description in https://arxiv.org/abs/1503.08895. |
|
""" |
|
|
|
def __init__( |
|
self, |
|
vocab_size: int, |
|
sentence_size: int, |
|
linear_output_size: int, |
|
embedding_size: int = 16, |
|
memory_size: Optional[int] = 128, |
|
num_hops: int = 1, |
|
nonlin: Callable[[Any], Any] = jax.nn.relu, |
|
apply_embeddings: bool = True, |
|
init_func: hk.initializers.Initializer = jnp.zeros, |
|
use_ln: bool = False, |
|
name: str = 'memnet') -> None: |
|
"""Constructor. |
|
|
|
Args: |
|
      vocab_size: the number of words in the dictionary (each story, query and

        answer contains symbols coming from this dictionary).
|
sentence_size: the dimensionality of each memory. |
|
linear_output_size: the dimensionality of the output of the last layer |
|
of the model. |
|
embedding_size: the dimensionality of the latent space to where all |
|
memories are projected. |
|
memory_size: the number of memories provided. |
|
num_hops: the number of layers in the model. |
|
nonlin: non-linear transformation applied at the end of each layer. |
|
      apply_embeddings: whether to apply embeddings.
|
init_func: initialization function for the biases. |
|
use_ln: whether to use layer normalisation in the model. |
|
name: the name of the model. |
|
""" |
|
super().__init__(name=name) |
|
self._vocab_size = vocab_size |
|
self._embedding_size = embedding_size |
|
self._sentence_size = sentence_size |
|
self._memory_size = memory_size |
|
self._linear_output_size = linear_output_size |
|
self._num_hops = num_hops |
|
self._nonlin = nonlin |
|
self._apply_embeddings = apply_embeddings |
|
self._init_func = init_func |
|
self._use_ln = use_ln |
|
|
|
self._encodings = _position_encoding(sentence_size, embedding_size) |
|
|
|
def __call__( |
|
self, |
|
node_fts: _Array, |
|
edge_fts: _Array, |
|
graph_fts: _Array, |
|
adj_mat: _Array, |
|
hidden: _Array, |
|
**unused_kwargs, |
|
  ) -> Tuple[_Array, Optional[_Array]]:
|
"""MemNet inference step.""" |
|
|
|
del hidden |
|
node_and_graph_fts = jnp.concatenate([node_fts, graph_fts[:, None]], |
|
axis=1) |
|
edge_fts_padded = jnp.pad(edge_fts * adj_mat[..., None], |
|
((0, 0), (0, 1), (0, 1), (0, 0))) |
|
    nxt_hidden = jax.vmap(self._apply, in_axes=1, out_axes=1)(
        node_and_graph_fts, edge_fts_padded)
|
|
|
|
|
    # Broadcast the hidden state of the appended graph-feature slot back over
    # the node slots.
    nxt_hidden = nxt_hidden[:, :-1] + nxt_hidden[:, -1:]
|
return nxt_hidden, None |
|
|
|
def _apply(self, queries: _Array, stories: _Array) -> _Array: |
|
"""Apply Memory Network to the queries and stories. |
|
|
|
Args: |
|
queries: Tensor of shape [batch_size, sentence_size]. |
|
stories: Tensor of shape [batch_size, memory_size, sentence_size]. |
|
|
|
Returns: |
|
Tensor of shape [batch_size, vocab_size]. |
|
""" |
|
if self._apply_embeddings: |
|
query_biases = hk.get_parameter( |
|
'query_biases', |
|
shape=[self._vocab_size - 1, self._embedding_size], |
|
init=self._init_func) |
|
stories_biases = hk.get_parameter( |
|
'stories_biases', |
|
shape=[self._vocab_size - 1, self._embedding_size], |
|
init=self._init_func) |
|
memory_biases = hk.get_parameter( |
|
'memory_contents', |
|
shape=[self._memory_size, self._embedding_size], |
|
init=self._init_func) |
|
output_biases = hk.get_parameter( |
|
'output_biases', |
|
shape=[self._vocab_size - 1, self._embedding_size], |
|
init=self._init_func) |
|
|
|
nil_word_slot = jnp.zeros([1, self._embedding_size]) |
|
|
|
|
|
if self._apply_embeddings: |
|
stories_biases = jnp.concatenate([stories_biases, nil_word_slot], axis=0) |
|
memory_embeddings = jnp.take( |
|
stories_biases, stories.reshape([-1]).astype(jnp.int32), |
|
axis=0).reshape(list(stories.shape) + [self._embedding_size]) |
|
memory_embeddings = jnp.pad( |
|
memory_embeddings, |
|
((0, 0), (0, self._memory_size - jnp.shape(memory_embeddings)[1]), |
|
(0, 0), (0, 0))) |
|
memory = jnp.sum(memory_embeddings * self._encodings, 2) + memory_biases |
|
else: |
|
memory = stories |
|
|
|
|
|
|
|
|
|
if self._apply_embeddings: |
|
query_biases = jnp.concatenate([query_biases, nil_word_slot], axis=0) |
|
query_embeddings = jnp.take( |
|
query_biases, queries.reshape([-1]).astype(jnp.int32), |
|
axis=0).reshape(list(queries.shape) + [self._embedding_size]) |
|
|
|
query_input_embedding = jnp.sum(query_embeddings * self._encodings, 1) |
|
else: |
|
query_input_embedding = queries |
|
|
|
|
|
if self._apply_embeddings: |
|
output_biases = jnp.concatenate([output_biases, nil_word_slot], axis=0) |
|
output_embeddings = jnp.take( |
|
output_biases, stories.reshape([-1]).astype(jnp.int32), |
|
axis=0).reshape(list(stories.shape) + [self._embedding_size]) |
|
output_embeddings = jnp.pad( |
|
output_embeddings, |
|
((0, 0), (0, self._memory_size - jnp.shape(output_embeddings)[1]), |
|
(0, 0), (0, 0))) |
|
output = jnp.sum(output_embeddings * self._encodings, 2) |
|
else: |
|
output = stories |
|
|
|
intermediate_linear = hk.Linear(self._embedding_size, with_bias=False) |
|
|
|
|
|
output_linear = hk.Linear(self._linear_output_size, with_bias=False) |
|
|
|
for hop_number in range(self._num_hops): |
|
query_input_embedding_transposed = jnp.transpose( |
|
jnp.expand_dims(query_input_embedding, -1), [0, 2, 1]) |
|
|
|
|
|
probs = jax.nn.softmax( |
|
jnp.sum(memory * query_input_embedding_transposed, 2)) |
|
|
|
|
|
transposed_probs = jnp.transpose(jnp.expand_dims(probs, -1), [0, 2, 1]) |
|
transposed_output_embeddings = jnp.transpose(output, [0, 2, 1]) |
|
|
|
|
|
layer_output = jnp.sum(transposed_output_embeddings * transposed_probs, 2) |
|
|
|
|
|
if hop_number == self._num_hops - 1: |
|
|
|
|
|
output_layer = output_linear(query_input_embedding + layer_output) |
|
else: |
|
output_layer = intermediate_linear(query_input_embedding + layer_output) |
|
|
|
query_input_embedding = output_layer |
|
if self._nonlin: |
|
output_layer = self._nonlin(output_layer) |
|
|
|
|
|
ret = hk.Linear(self._vocab_size, with_bias=False)(output_layer) |
|
|
|
if self._use_ln: |
|
ln = hk.LayerNorm(axis=-1, create_scale=True, create_offset=True) |
|
ret = ln(ret) |
|
|
|
return ret |
|
|
|
|
|
class MemNetFull(MemNetMasked): |
|
"""Memory Networks with full adjacency matrix.""" |
|
|
|
def __call__(self, node_fts: _Array, edge_fts: _Array, graph_fts: _Array, |
|
               adj_mat: _Array, hidden: _Array,
               **unused_kwargs) -> Tuple[_Array, Optional[_Array]]:
|
adj_mat = jnp.ones_like(adj_mat) |
|
return super().__call__(node_fts, edge_fts, graph_fts, adj_mat, hidden) |
|
|
|
|
|
ProcessorFactory = Callable[[int], Processor] |
|
|
|
|
|
def get_processor_factory(kind: str, |
|
use_ln: bool, |
|
nb_triplet_fts: int, |
|
nb_heads: Optional[int] = None) -> ProcessorFactory: |
|
"""Returns a processor factory. |
|
|
|
Args: |
|
kind: One of the available types of processor. |
|
use_ln: Whether the processor passes the output through a layernorm layer. |
|
nb_triplet_fts: How many triplet features to compute. |
|
nb_heads: Number of attention heads for GAT processors. |
|
Returns: |
|
A callable that takes an `out_size` parameter (equal to the hidden |
|
dimension of the network) and returns a processor instance. |
|
""" |
|
def _factory(out_size: int): |
|
if kind == 'deepsets': |
|
processor = DeepSets( |
|
out_size=out_size, |
|
msgs_mlp_sizes=[out_size, out_size], |
|
use_ln=use_ln, |
|
use_triplets=False, |
|
nb_triplet_fts=0 |
|
) |
|
elif kind == 'gat': |
|
processor = GAT( |
|
out_size=out_size, |
|
nb_heads=nb_heads, |
|
use_ln=use_ln, |
|
) |
|
elif kind == 'gat_full': |
|
processor = GATFull( |
|
out_size=out_size, |
|
nb_heads=nb_heads, |
|
use_ln=use_ln |
|
) |
|
elif kind == 'gatv2': |
|
processor = GATv2( |
|
out_size=out_size, |
|
nb_heads=nb_heads, |
|
use_ln=use_ln |
|
) |
|
elif kind == 'gatv2_full': |
|
processor = GATv2Full( |
|
out_size=out_size, |
|
nb_heads=nb_heads, |
|
use_ln=use_ln |
|
) |
|
elif kind == 'memnet_full': |
|
processor = MemNetFull( |
|
vocab_size=out_size, |
|
sentence_size=out_size, |
|
linear_output_size=out_size, |
|
) |
|
elif kind == 'memnet_masked': |
|
processor = MemNetMasked( |
|
vocab_size=out_size, |
|
sentence_size=out_size, |
|
linear_output_size=out_size, |
|
) |
|
elif kind == 'mpnn': |
|
processor = MPNN( |
|
out_size=out_size, |
|
msgs_mlp_sizes=[out_size, out_size], |
|
use_ln=use_ln, |
|
use_triplets=False, |
|
nb_triplet_fts=0, |
|
) |
|
elif kind == 'pgn': |
|
processor = PGN( |
|
out_size=out_size, |
|
msgs_mlp_sizes=[out_size, out_size], |
|
use_ln=use_ln, |
|
use_triplets=False, |
|
nb_triplet_fts=0, |
|
) |
|
elif kind == 'pgn_mask': |
|
processor = PGNMask( |
|
out_size=out_size, |
|
msgs_mlp_sizes=[out_size, out_size], |
|
use_ln=use_ln, |
|
use_triplets=False, |
|
nb_triplet_fts=0, |
|
) |
|
elif kind == 'triplet_mpnn': |
|
processor = MPNN( |
|
out_size=out_size, |
|
msgs_mlp_sizes=[out_size, out_size], |
|
use_ln=use_ln, |
|
use_triplets=True, |
|
nb_triplet_fts=nb_triplet_fts, |
|
) |
|
elif kind == 'triplet_pgn': |
|
processor = PGN( |
|
out_size=out_size, |
|
msgs_mlp_sizes=[out_size, out_size], |
|
use_ln=use_ln, |
|
use_triplets=True, |
|
nb_triplet_fts=nb_triplet_fts, |
|
) |
|
elif kind == 'triplet_pgn_mask': |
|
processor = PGNMask( |
|
out_size=out_size, |
|
msgs_mlp_sizes=[out_size, out_size], |
|
use_ln=use_ln, |
|
use_triplets=True, |
|
nb_triplet_fts=nb_triplet_fts, |
|
) |
|
elif kind == 'gpgn': |
|
processor = PGN( |
|
out_size=out_size, |
|
msgs_mlp_sizes=[out_size, out_size], |
|
use_ln=use_ln, |
|
use_triplets=False, |
|
nb_triplet_fts=nb_triplet_fts, |
|
gated=True, |
|
) |
|
elif kind == 'gpgn_mask': |
|
processor = PGNMask( |
|
out_size=out_size, |
|
msgs_mlp_sizes=[out_size, out_size], |
|
use_ln=use_ln, |
|
use_triplets=False, |
|
nb_triplet_fts=nb_triplet_fts, |
|
gated=True, |
|
) |
|
elif kind == 'gmpnn': |
|
processor = MPNN( |
|
out_size=out_size, |
|
msgs_mlp_sizes=[out_size, out_size], |
|
use_ln=use_ln, |
|
use_triplets=False, |
|
nb_triplet_fts=nb_triplet_fts, |
|
gated=True, |
|
) |
|
elif kind == 'triplet_gpgn': |
|
processor = PGN( |
|
out_size=out_size, |
|
msgs_mlp_sizes=[out_size, out_size], |
|
use_ln=use_ln, |
|
use_triplets=True, |
|
nb_triplet_fts=nb_triplet_fts, |
|
gated=True, |
|
) |
|
elif kind == 'triplet_gpgn_mask': |
|
processor = PGNMask( |
|
out_size=out_size, |
|
msgs_mlp_sizes=[out_size, out_size], |
|
use_ln=use_ln, |
|
use_triplets=True, |
|
nb_triplet_fts=nb_triplet_fts, |
|
gated=True, |
|
) |
|
elif kind == 'triplet_gmpnn': |
|
processor = MPNN( |
|
out_size=out_size, |
|
msgs_mlp_sizes=[out_size, out_size], |
|
use_ln=use_ln, |
|
use_triplets=True, |
|
nb_triplet_fts=nb_triplet_fts, |
|
gated=True, |
|
) |
|
else: |
|
raise ValueError('Unexpected processor kind ' + kind) |
|
|
|
return processor |
|
|
|
return _factory |
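
# A hedged usage sketch (illustration only): the factory is specialised to a
# processor kind up front and instantiated at the network's hidden size later.
# The kind and size below are assumptions, not canonical settings.
def _example_factory_usage():
  factory = get_processor_factory(
      'triplet_gmpnn', use_ln=True, nb_triplet_fts=8)

  def forward(node_fts, edge_fts, graph_fts, adj_mat, hidden):
    processor = factory(128)  # out_size = hidden dimension of the network.
    return processor(node_fts, edge_fts, graph_fts, adj_mat, hidden)

  return hk.transform(forward)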
|
|
|
|
|
def _position_encoding(sentence_size: int, embedding_size: int) -> np.ndarray: |
|
"""Position Encoding described in section 4.1 [1].""" |
|
encoding = np.ones((embedding_size, sentence_size), dtype=np.float32) |
|
ls = sentence_size + 1 |
|
le = embedding_size + 1 |
|
for i in range(1, le): |
|
for j in range(1, ls): |
|
encoding[i - 1, j - 1] = (i - (le - 1) / 2) * (j - (ls - 1) / 2) |
|
encoding = 1 + 4 * encoding / embedding_size / sentence_size |
|
return np.transpose(encoding) |
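
# A small worked check (illustration only) of the encoding above: with
# sentence_size=3 and embedding_size=2 we get ls=4 and le=3, so the i=1 row is
# constant and the i=2 row scales linearly in j, giving the values below.
def _example_position_encoding():
  enc = _position_encoding(sentence_size=3, embedding_size=2)
  assert enc.shape == (3, 2)  # (sentence_size, embedding_size) post-transpose.
  np.testing.assert_allclose(
      enc, [[1., 2. / 3.], [1., 4. / 3.], [1., 2.]], rtol=1e-6)
  return enc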
|
|