"""Probing utilities. |
|
|
|
The dataflow for an algorithm is represented by `(stage, loc, type, data)` |
|
"probes" that are valid under that algorithm's spec (see `specs.py`). |
|
|
|
When constructing probes, it is convenient to represent these fields in a

nested format (`ProbesDict`) to facilitate efficient context-based look-up.
|
|
|
""" |
|
|
|
import functools |
|
from typing import Dict, List, Tuple, Union |
|
|
|
import attr |
|
from clrs._src import specs |
|
import jax |
|
import jax.numpy as jnp |
|
import numpy as np |
|
import tensorflow as tf |
|
|
|
|
|
_Location = specs.Location |
|
_Stage = specs.Stage |
|
_Type = specs.Type |
|
_OutputClass = specs.OutputClass |
|
|
|
_Array = np.ndarray |
|
_Data = Union[_Array, List[_Array]] |
|
_DataOrType = Union[_Data, str] |
|
|
|
ProbesDict = Dict[ |
|
str, Dict[str, Dict[str, Dict[str, _DataOrType]]]] |
|
|
|
|
|
def _convert_to_str(element): |
|
if isinstance(element, tf.Tensor): |
|
return element.numpy().decode('utf-8') |
|
elif isinstance(element, (np.ndarray, bytes)): |
|
return element.decode('utf-8') |
|
else: |
|
return element |
|
|
|
|
|
|
|
|
|
@jax.tree_util.register_pytree_node_class |
|
@attr.define |
|
class DataPoint: |
|
"""Describes a data point.""" |
|
|
|
_name: str |
|
_location: str |
|
_type_: str |
|
data: _Array |
|
|
|
@property |
|
def name(self): |
|
return _convert_to_str(self._name) |
|
|
|
@property |
|
def location(self): |
|
return _convert_to_str(self._location) |
|
|
|
@property |
|
def type_(self): |
|
return _convert_to_str(self._type_) |
|
|
|
def __repr__(self): |
|
s = f'DataPoint(name="{self.name}",\tlocation={self.location},\t' |
|
return s + f'type={self.type_},\tdata=Array{self.data.shape})' |
|
|
|
def tree_flatten(self): |
|
data = (self.data,) |
|
meta = (self.name, self.location, self.type_) |
|
return data, meta |
|
|
|
@classmethod |
|
def tree_unflatten(cls, meta, data): |
|
name, location, type_ = meta |
|
subdata, = data |
|
return DataPoint(name, location, type_, subdata) |
|
|
|
|
|
class ProbeError(Exception): |
|
pass |
|
|
|
|
|
def initialize(spec: specs.Spec) -> ProbesDict: |
|
"""Initializes an empty `ProbesDict` corresponding with the provided spec.""" |
|
probes = dict() |
|
for stage in [_Stage.INPUT, _Stage.OUTPUT, _Stage.HINT]: |
|
probes[stage] = {} |
|
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]: |
|
probes[stage][loc] = {} |
|
|
|
for name in spec: |
|
stage, loc, t = spec[name] |
|
probes[stage][loc][name] = {} |
|
probes[stage][loc][name]['data'] = [] |
|
probes[stage][loc][name]['type_'] = t |
|
|
|
|
|
return probes |
|
|
|
|
|
def push(probes: ProbesDict, stage: str, next_probe): |
|
"""Pushes a probe into an existing `ProbesDict`.""" |
|
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]: |
|
for name in probes[stage][loc]: |
|
if name not in next_probe: |
|
raise ProbeError(f'Missing probe for {name}.') |
|
if isinstance(probes[stage][loc][name]['data'], _Array): |
|
        raise ProbeError('Attempting to push to finalized `ProbesDict`.')
|
|
|
|
|
probes[stage][loc][name]['data'].append(next_probe[name]) |
|
|
|
|
|
def finalize(probes: ProbesDict): |
|
"""Finalizes a `ProbesDict` by stacking/squeezing `data` field.""" |
|
for stage in [_Stage.INPUT, _Stage.OUTPUT, _Stage.HINT]: |
|
for loc in [_Location.NODE, _Location.EDGE, _Location.GRAPH]: |
|
for name in probes[stage][loc]: |
|
if isinstance(probes[stage][loc][name]['data'], _Array): |
|
          raise ProbeError('Attempting to re-finalize a finalized `ProbesDict`.')
|
if stage == _Stage.HINT: |
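          # Hints are pushed once per timestep; stack them along a new
          # leading (time) axis.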
|
|
|
probes[stage][loc][name]['data'] = np.stack( |
|
probes[stage][loc][name]['data']) |
|
else: |
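          # Inputs and outputs are pushed exactly once; squeeze away the
          # resulting singleton axes.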
|
|
|
probes[stage][loc][name]['data'] = np.squeeze( |
|
np.array(probes[stage][loc][name]['data'])) |
|
|
|
|
|
def split_stages( |
|
probes: ProbesDict, |
|
spec: specs.Spec, |
|
) -> Tuple[List[DataPoint], List[DataPoint], List[DataPoint]]: |
|
"""Splits contents of `ProbesDict` into `DataPoint`s by stage.""" |
|
|
|
inputs = [] |
|
outputs = [] |
|
hints = [] |
|
|
|
for name in spec: |
|
stage, loc, t = spec[name] |
|
|
|
if stage not in probes: |
|
raise ProbeError(f'Missing stage {stage}.') |
|
if loc not in probes[stage]: |
|
raise ProbeError(f'Missing location {loc}.') |
|
if name not in probes[stage][loc]: |
|
raise ProbeError(f'Missing probe {name}.') |
|
if 'type_' not in probes[stage][loc][name]: |
|
raise ProbeError(f'Probe {name} missing attribute `type_`.') |
|
if 'data' not in probes[stage][loc][name]: |
|
raise ProbeError(f'Probe {name} missing attribute `data`.') |
|
if t != probes[stage][loc][name]['type_']: |
|
      raise ProbeError(f'Probe {name} of incorrect type: expected {t}, '
                       f'got {probes[stage][loc][name]["type_"]}.')
|
|
|
data = probes[stage][loc][name]['data'] |
|
    if not isinstance(data, _Array):
|
raise ProbeError((f'Invalid `data` for probe "{name}". ' + |
|
'Did you forget to call `probing.finalize`?')) |
|
|
|
if t in [_Type.MASK, _Type.MASK_ONE, _Type.CATEGORICAL]: |
|
|
|
if not ((data == 0) | (data == 1) | (data == -1)).all(): |
|
        raise ProbeError(f'Expected 0|1|-1 `data` for probe "{name}".')
|
|
|
if t in [_Type.MASK_ONE, _Type.CATEGORICAL |
|
] and not np.all(np.sum(np.abs(data), -1) == 1): |
|
raise ProbeError(f'Expected one-hot `data` for probe "{name}"') |
|
|
|
dim_to_expand = 1 if stage == _Stage.HINT else 0 |
|
data_point = DataPoint(name=name, location=loc, type_=t, |
|
data=np.expand_dims(data, dim_to_expand)) |
|
|
|
if stage == _Stage.INPUT: |
|
inputs.append(data_point) |
|
elif stage == _Stage.OUTPUT: |
|
outputs.append(data_point) |
|
else: |
|
hints.append(data_point) |
|
|
|
return inputs, outputs, hints |
|
|
|
|
|
|
|
|
|
|
|
def array(A_pos: np.ndarray) -> np.ndarray: |
|
"""Constructs an `array` probe.""" |
|
probe = np.arange(A_pos.shape[0]) |
|
for i in range(1, A_pos.shape[0]): |
|
probe[A_pos[i]] = A_pos[i - 1] |
|
return probe |
|
|
|
|
|
def array_cat(A: np.ndarray, n: int) -> np.ndarray: |
|
"""Constructs an `array_cat` probe.""" |
|
assert n > 0 |
|
probe = np.zeros((A.shape[0], n)) |
|
for i in range(A.shape[0]): |
|
probe[i, A[i]] = 1 |
|
return probe |
|
|
|
|
|
def heap(A_pos: np.ndarray, heap_size: int) -> np.ndarray: |
|
"""Constructs a `heap` probe.""" |
|
assert heap_size > 0 |
|
probe = np.arange(A_pos.shape[0]) |
|
for i in range(1, heap_size): |
|
probe[A_pos[i]] = A_pos[(i - 1) // 2] |
|
return probe |
|
|
|
|
|
def graph(A: np.ndarray) -> np.ndarray: |
|
"""Constructs a `graph` probe.""" |
|
  # Mark an edge wherever A is non-zero and add self-loops along the diagonal.
  probe = ((A + np.eye(A.shape[0])) != 0) * 1.0
|
return probe |
|
|
|
|
|
def mask_one(i: int, n: int) -> np.ndarray: |
|
"""Constructs a `mask_one` probe.""" |
|
assert n > i |
|
probe = np.zeros(n) |
|
probe[i] = 1 |
|
return probe |
|
|
|
|
|
def strings_id(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray: |
|
"""Constructs a `strings_id` probe.""" |
|
probe_T = np.zeros(T_pos.shape[0]) |
|
probe_P = np.ones(P_pos.shape[0]) |
|
return np.concatenate([probe_T, probe_P]) |
|
|
|
|
|
def strings_pair(pair_probe: np.ndarray) -> np.ndarray: |
|
"""Constructs a `strings_pair` probe.""" |
|
n = pair_probe.shape[0] |
|
m = pair_probe.shape[1] |
|
probe_ret = np.zeros((n + m, n + m)) |
|
for i in range(0, n): |
|
for j in range(0, m): |
|
probe_ret[i, j + n] = pair_probe[i, j] |
|
return probe_ret |
|
|
|
|
|
def strings_pair_cat(pair_probe: np.ndarray, nb_classes: int) -> np.ndarray: |
|
"""Constructs a `strings_pair_cat` probe.""" |
|
assert nb_classes > 0 |
|
n = pair_probe.shape[0] |
|
m = pair_probe.shape[1] |
|
|
|
|
|
probe_ret = np.zeros((n + m, n + m, nb_classes + 1)) |
|
for i in range(0, n): |
|
for j in range(0, m): |
|
probe_ret[i, j + n, int(pair_probe[i, j])] = _OutputClass.POSITIVE |
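  # Every remaining pair (text-to-text, and anything anchored at a pattern
  # position) carries no class label, so mark it with the extra MASKED channel.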
|
|
|
|
|
for i_1 in range(0, n): |
|
for i_2 in range(0, n): |
|
probe_ret[i_1, i_2, nb_classes] = _OutputClass.MASKED |
|
for j_1 in range(0, m): |
|
for x in range(0, n + m): |
|
probe_ret[j_1 + n, x, nb_classes] = _OutputClass.MASKED |
|
return probe_ret |
|
|
|
|
|
def strings_pi(T_pos: np.ndarray, P_pos: np.ndarray, |
|
pi: np.ndarray) -> np.ndarray: |
|
"""Constructs a `strings_pi` probe.""" |
|
probe = np.arange(T_pos.shape[0] + P_pos.shape[0]) |
|
for j in range(P_pos.shape[0]): |
|
probe[T_pos.shape[0] + P_pos[j]] = T_pos.shape[0] + pi[P_pos[j]] |
|
return probe |
|
|
|
|
|
def strings_pos(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray: |
|
"""Constructs a `strings_pos` probe.""" |
|
probe_T = np.copy(T_pos) * 1.0 / T_pos.shape[0] |
|
probe_P = np.copy(P_pos) * 1.0 / P_pos.shape[0] |
|
return np.concatenate([probe_T, probe_P]) |
|
|
|
|
|
def strings_pred(T_pos: np.ndarray, P_pos: np.ndarray) -> np.ndarray: |
|
"""Constructs a `strings_pred` probe.""" |
|
probe = np.arange(T_pos.shape[0] + P_pos.shape[0]) |
|
for i in range(1, T_pos.shape[0]): |
|
probe[T_pos[i]] = T_pos[i - 1] |
|
for j in range(1, P_pos.shape[0]): |
|
probe[T_pos.shape[0] + P_pos[j]] = T_pos.shape[0] + P_pos[j - 1] |
|
return probe |
|
|
|
|
|
@functools.partial(jnp.vectorize, signature='(n)->(n,n),(n)') |
|
def predecessor_to_cyclic_predecessor_and_first( |
|
pointers: jnp.ndarray) -> Tuple[jnp.ndarray, jnp.ndarray]: |
|
"""Converts predecessor pointers to cyclic predecessor + first node mask. |
|
|
|
This function assumes that the pointers represent a linear order of the nodes |
|
(akin to a linked list), where each node points to its predecessor and the |
|
  first node points to itself. It returns the same pointers in one-hot matrix

  form, except that the first node now points to the last node, together with

  a `mask_one` vector marking the first node.
|
|
|
Example: |
|
``` |
|
pointers = [2, 1, 1] |
|
P = [[0, 0, 1], |
|
[1, 0, 0], |
|
[0, 1, 0]], |
|
M = [0, 1, 0] |
|
``` |
|
|
|
Args: |
|
pointers: array of shape [N] containing pointers. The pointers are assumed |
|
to describe a linear order such that `pointers[i]` is the predecessor |
|
of node `i`. |
|
|
|
Returns: |
|
    Permutation pointers `P` of shape [N, N] and one-hot vector `M` of shape [N].
|
""" |
|
nb_nodes = pointers.shape[-1] |
|
pointers_one_hot = jax.nn.one_hot(pointers, nb_nodes) |
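  # The last node in the order is the only one that no other node points to.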
|
|
|
last = pointers_one_hot.sum(-2).argmin() |
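  # The first node is the only one that points to itself.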
|
|
|
first = pointers_one_hot.diagonal().argmax() |
|
mask = jax.nn.one_hot(first, nb_nodes) |
|
pointers_one_hot += mask[..., None] * jax.nn.one_hot(last, nb_nodes) |
|
pointers_one_hot -= mask[..., None] * mask |
|
return pointers_one_hot, mask |
|
|