|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Encoder utilities.""" |
|
|
|
import functools |
|
import chex |
|
from clrs._src import probing |
|
from clrs._src import specs |
|
import haiku as hk |
|
import jax.numpy as jnp |
|
|
|
# Short aliases for types used throughout this module.
_Array = chex.Array

_DataPoint = probing.DataPoint

# Enumeration-like constants from the CLRS spec (stage/location/type of data).
_Location = specs.Location

_Spec = specs.Spec

_Stage = specs.Stage

_Type = specs.Type
|
|
|
|
|
def construct_encoders(stage: str, loc: str, t: str,
                       hidden_dim: int, init: str, name: str):
  """Constructs encoders.

  Args:
    stage: Stage of the data point (e.g. hint) as in `specs.Stage`.
    loc: Location of the data point as in `specs.Location`.
    t: Type of the data point as in `specs.Type`.
    hidden_dim: Output dimensionality of each linear encoder.
    init: Initialiser scheme; one of 'default' or 'xavier_on_scalars'.
    name: Prefix for the haiku module name of the encoder(s).

  Returns:
    A list with one `hk.Linear` encoder, or two for edge pointers.

  Raises:
    ValueError: If `init` is not a supported initialiser scheme.
  """
  if init not in ('default', 'xavier_on_scalars'):
    raise ValueError(f'Encoder initialiser {init} not supported.')
  initialiser = None
  if init == 'xavier_on_scalars' and stage == _Stage.HINT and t == _Type.SCALAR:
    # Special-cased init for scalar hints: truncated normal scaled by the
    # hidden dimension.
    initialiser = hk.initializers.TruncatedNormal(
        stddev=1.0 / jnp.sqrt(hidden_dim))
  def make_linear():
    return hk.Linear(
        hidden_dim, w_init=initialiser, name=f'{name}_enc_linear')
  encoders = [make_linear()]
  if loc == _Location.EDGE and t == _Type.POINTER:
    # Edge pointers get a second encoder; see `accum_edge_fts`.
    encoders.append(make_linear())
  return encoders
|
|
|
|
|
def preprocess(dp: _DataPoint, nb_nodes: int) -> _DataPoint:
  """Pre-process data point.

  Make sure that the data is ready to be encoded into features.
  If the data is of POINTER type, we expand the compressed index representation
  to a full one-hot. But if the data is a SOFT_POINTER, the representation
  is already expanded and we just overwrite the type as POINTER so that
  it is treated as such for encoding.

  Args:
    dp: A DataPoint to prepare for encoding.
    nb_nodes: Number of nodes in the graph, necessary to expand pointers to
      the right dimension.
  Returns:
    The datapoint, with data and possibly type modified.
  """
  if dp.type_ == _Type.POINTER:
    # Expand compressed node indices into a one-hot over the nodes.
    return probing.DataPoint(
        name=dp.name, location=dp.location, type_=dp.type_,
        data=hk.one_hot(dp.data, nb_nodes))
  data = dp.data.astype(jnp.float32)
  # Soft pointers are already expanded; relabel them as plain pointers.
  new_type = _Type.POINTER if dp.type_ == _Type.SOFT_POINTER else dp.type_
  return probing.DataPoint(
      name=dp.name, location=dp.location, type_=new_type, data=data)
|
|
|
|
|
def accum_adj_mat(dp: _DataPoint, adj_mat: _Array) -> _Array:
  """Accumulates adjacency matrix."""
  is_node_pointer = (dp.location == _Location.NODE and
                     dp.type_ in (_Type.POINTER, _Type.PERMUTATION_POINTER))
  is_edge_mask = (dp.location == _Location.EDGE and dp.type_ == _Type.MASK)
  if is_node_pointer or is_edge_mask:
    # Symmetrise the data before thresholding so the adjacency is undirected.
    symmetrised = dp.data + jnp.transpose(dp.data, (0, 2, 1))
    # Node pointers are thresholded at 0.5, edge masks at 0.0.
    adj_mat += (symmetrised > (0.5 if is_node_pointer else 0.0))

  # Binarise the accumulated matrix.
  return (adj_mat > 0.).astype('float32')
|
|
|
|
|
def accum_edge_fts(encoders, dp: _DataPoint, edge_fts: _Array) -> _Array:
  """Encodes and accumulates edge features."""
  is_node_pointer = (dp.location == _Location.NODE and
                     dp.type_ in [_Type.POINTER, _Type.PERMUTATION_POINTER])
  if is_node_pointer:
    edge_fts += _encode_inputs(encoders, dp)
  elif dp.location == _Location.EDGE:
    encoding = _encode_inputs(encoders, dp)
    if dp.type_ == _Type.POINTER:
      # Edge pointers have a second encoder (see `construct_encoders`);
      # each encoding is averaged over a different node axis before summing.
      encoding_2 = encoders[1](jnp.expand_dims(dp.data, -1))
      edge_fts += jnp.mean(encoding, axis=1) + jnp.mean(encoding_2, axis=2)
    else:
      edge_fts += encoding

  return edge_fts
|
|
|
|
|
def accum_node_fts(encoders, dp: _DataPoint, node_fts: _Array) -> _Array:
  """Encodes and accumulates node features."""
  pointer_types = (_Type.POINTER, _Type.PERMUTATION_POINTER)
  # Non-pointer node data, and graph-level pointers, are encoded into nodes.
  encode_into_nodes = (
      (dp.location == _Location.NODE and dp.type_ not in pointer_types) or
      (dp.location == _Location.GRAPH and dp.type_ == _Type.POINTER))
  if encode_into_nodes:
    node_fts += _encode_inputs(encoders, dp)

  return node_fts
|
|
|
|
|
def accum_graph_fts(encoders, dp: _DataPoint,
                    graph_fts: _Array) -> _Array:
  """Encodes and accumulates graph features."""
  # Graph-level pointers are handled in `accum_node_fts`, not here.
  if dp.location != _Location.GRAPH or dp.type_ == _Type.POINTER:
    return graph_fts
  graph_fts += _encode_inputs(encoders, dp)
  return graph_fts
|
|
|
|
|
def _encode_inputs(encoders, dp: _DataPoint) -> _Array:
  """Applies the first encoder to the data point's data."""
  if dp.type_ == _Type.CATEGORICAL:
    # Categorical data is passed through as-is.
    inputs = dp.data
  else:
    # All other types gain a trailing singleton feature dimension.
    inputs = jnp.expand_dims(dp.data, -1)
  return encoders[0](inputs)
|
|