# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sampling utilities.""" | |
import abc | |
import collections | |
import inspect | |
import types | |
from typing import Any, Callable, List, Optional, Tuple | |
from absl import logging | |
from clrs._src import algorithms | |
from clrs._src import probing | |
from clrs._src import specs | |
import jax | |
import numpy as np | |
_Array = np.ndarray | |
_DataPoint = probing.DataPoint | |
Trajectory = List[_DataPoint] | |
Trajectories = List[Trajectory] | |
Algorithm = Callable[..., Any] | |
Features = collections.namedtuple('Features', ['inputs', 'hints', 'lengths']) | |
FeaturesChunked = collections.namedtuple(
    'FeaturesChunked', ['inputs', 'hints', 'is_first', 'is_last'])
Feedback = collections.namedtuple('Feedback', ['features', 'outputs'])

# CLRS-30 baseline spec.
CLRS30 = types.MappingProxyType({
    'train': {
        'num_samples': 1000,
        'length': 16,
        'seed': 1,
    },
    'val': {
        'num_samples': 32,
        'length': 16,
        'seed': 2,
    },
    'test': {
        'num_samples': 32,
        'length': 64,
        'seed': 3,
    },
})
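
# Illustrative only: how the CLRS-30 settings above map onto sampler
# construction via `build_sampler` (defined below). The algorithm name must
# be a key of `SAMPLERS`; 'bfs' is used here purely as an example.
#
#   train_sampler, spec = build_sampler(
#       'bfs',
#       num_samples=CLRS30['train']['num_samples'],
#       length=CLRS30['train']['length'],
#       seed=CLRS30['train']['seed'])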


class Sampler(abc.ABC):
  """Sampler abstract base class."""

  def __init__(
      self,
      algorithm: Algorithm,
      spec: specs.Spec,
      num_samples: int,
      *args,
      seed: Optional[int] = None,
      **kwargs,
  ):
    """Initializes a `Sampler`.

    Args:
      algorithm: The algorithm to sample from.
      spec: The algorithm spec.
      num_samples: Number of algorithm unrolls to sample. If positive, all the
        samples will be generated in the constructor, and at each call
        of the `next` method a batch will be randomly selected among them.
        If -1, samples are generated on the fly with each call to `next`.
      *args: Algorithm args.
      seed: RNG seed.
      **kwargs: Algorithm kwargs.
    """
    # Use `RandomState` to ensure deterministic sampling across Numpy versions.
    self._rng = np.random.RandomState(seed)
    self._spec = spec
    self._num_samples = num_samples
    self._algorithm = algorithm
    self._args = args
    self._kwargs = kwargs

    if num_samples < 0:
      logging.warning('Sampling dataset on-the-fly, unlimited samples.')
      # Just get an initial estimate of max hint length.
      self.max_steps = -1
      for _ in range(1000):
        data = self._sample_data(*args, **kwargs)
        _, probes = algorithm(*data)
        _, _, hint = probing.split_stages(probes, spec)
        for dp in hint:
          assert dp.data.shape[1] == 1  # batching axis
          if dp.data.shape[0] > self.max_steps:
            self.max_steps = dp.data.shape[0]
    else:
      logging.info('Creating a dataset with %i samples.', num_samples)
      (self._inputs, self._outputs, self._hints,
       self._lengths) = self._make_batch(num_samples, spec, 0, algorithm,
                                         *args, **kwargs)

  def _make_batch(self, num_samples: int, spec: specs.Spec, min_length: int,
                  algorithm: Algorithm, *args, **kwargs):
    """Generate a batch of data."""
    inputs = []
    outputs = []
    hints = []

    for _ in range(num_samples):
      data = self._sample_data(*args, **kwargs)
      _, probes = algorithm(*data)
      inp, outp, hint = probing.split_stages(probes, spec)
      inputs.append(inp)
      outputs.append(outp)
      hints.append(hint)
      if len(hints) % 1000 == 0:
        logging.info('%i samples created', len(hints))

    # Batch and pad trajectories to max(T).
    inputs = _batch_io(inputs)
    outputs = _batch_io(outputs)
    hints, lengths = _batch_hints(hints, min_length)
    return inputs, outputs, hints, lengths

  def next(self, batch_size: Optional[int] = None) -> Feedback:
    """Subsamples trajectories from the pre-generated dataset.

    Args:
      batch_size: Optional batch size. If `None`, returns entire dataset.

    Returns:
      Subsampled trajectories.
    """
    if batch_size:
      if self._num_samples < 0:  # generate on the fly
        inputs, outputs, hints, lengths = self._make_batch(
            batch_size, self._spec, self.max_steps,
            self._algorithm, *self._args, **self._kwargs)
        if hints[0].data.shape[0] > self.max_steps:
          logging.warning('Increasing hint length from %i to %i',
                          self.max_steps, hints[0].data.shape[0])
          self.max_steps = hints[0].data.shape[0]
      else:
        if batch_size > self._num_samples:
          raise ValueError(
              f'Batch size {batch_size} > dataset size {self._num_samples}.')

        # Returns a fixed-size random batch.
        indices = self._rng.choice(self._num_samples, (batch_size,),
                                   replace=True)
        inputs = _subsample_data(self._inputs, indices, axis=0)
        outputs = _subsample_data(self._outputs, indices, axis=0)
        hints = _subsample_data(self._hints, indices, axis=1)
        lengths = self._lengths[indices]

    else:
      # Returns the full dataset.
      assert self._num_samples >= 0
      inputs = self._inputs
      hints = self._hints
      lengths = self._lengths
      outputs = self._outputs

    return Feedback(Features(inputs, hints, lengths), outputs)
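
  # Illustrative usage only (the algorithm name and arguments below are
  # examples, not part of this module's API surface):
  #
  #   sampler, _ = build_sampler('bfs', num_samples=100, length=16, seed=0)
  #   batch = sampler.next(batch_size=8)   # random batch of 8 trajectories
  #   full = sampler.next()                # the whole 100-sample dataset
  #
  # With `num_samples=-1`, each `next(batch_size)` call generates a fresh
  # batch on the fly instead of subsampling a pre-generated dataset.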

  @abc.abstractmethod
  def _sample_data(self, length: int, *args, **kwargs) -> List[_Array]:
    pass

  def _random_sequence(self, length, low=0.0, high=1.0):
    """Random sequence."""
    return self._rng.uniform(low=low, high=high, size=(length,))

  def _random_string(self, length, chars=4):
    """Random string."""
    return self._rng.randint(0, high=chars, size=(length,))

  def _random_er_graph(self, nb_nodes, p=0.5, directed=False, acyclic=False,
                       weighted=False, low=0.0, high=1.0):
    """Random Erdos-Renyi graph."""
    mat = self._rng.binomial(1, p, size=(nb_nodes, nb_nodes))
    if not directed:
      mat *= np.transpose(mat)
    elif acyclic:
      mat = np.triu(mat, k=1)
      p = self._rng.permutation(nb_nodes)  # To allow nontrivial solutions
      mat = mat[p, :][:, p]
    if weighted:
      weights = self._rng.uniform(low=low, high=high,
                                  size=(nb_nodes, nb_nodes))
      if not directed:
        weights *= np.transpose(weights)
        weights = np.sqrt(weights + 1e-3)  # Add epsilon to protect underflow
      mat = mat.astype(float) * weights
    return mat

  def _random_community_graph(self, nb_nodes, k=4, p=0.5, eps=0.01,
                              directed=False, acyclic=False, weighted=False,
                              low=0.0, high=1.0):
    """Random perturbed k-community graph."""
    mat = np.zeros((nb_nodes, nb_nodes))
    if k > nb_nodes:
      raise ValueError(f'Cannot generate graph of too many ({k}) communities.')
    los, his = [], []
    lo = 0
    for i in range(k):
      if i == k - 1:
        hi = nb_nodes
      else:
        hi = lo + nb_nodes // k
      mat[lo:hi, lo:hi] = self._random_er_graph(
          hi - lo, p=p, directed=directed,
          acyclic=acyclic, weighted=weighted,
          low=low, high=high)
      los.append(lo)
      his.append(hi)
      lo = hi
    toggle = self._random_er_graph(nb_nodes, p=eps, directed=directed,
                                   acyclic=acyclic, weighted=weighted,
                                   low=low, high=high)

    # Prohibit closing new cycles
    for i in range(k):
      for j in range(i):
        toggle[los[i]:his[i], los[j]:his[j]] *= 0

    mat = np.where(toggle > 0.0, (1.0 - (mat > 0.0)) * toggle, mat)
    p = self._rng.permutation(nb_nodes)  # To allow nontrivial solutions
    mat = mat[p, :][:, p]
    return mat

  def _random_bipartite_graph(self, n, m, p=0.25):
    """Random bipartite graph-based flow network."""
    nb_nodes = n + m + 2
    s = 0
    t = n + m + 1
    mat = np.zeros((nb_nodes, nb_nodes))
    mat[s, 1:n+1] = 1.0  # supersource
    mat[n+1:n+m+1, t] = 1.0  # supersink
    mat[1:n+1, n+1:n+m+1] = self._rng.binomial(1, p, size=(n, m))
    return mat
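
  # Node layout of the flow network built above (descriptive note only):
  #   node 0         -> supersource s
  #   nodes 1..n     -> left partition (unit-capacity edges from s)
  #   nodes n+1..n+m -> right partition (unit-capacity edges to t)
  #   node n+m+1     -> supersink t
  # Edges between the two partitions are sampled i.i.d. Bernoulli(p).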


def build_sampler(
    name: str,
    num_samples: int,
    *args,
    seed: Optional[int] = None,
    **kwargs,
) -> Tuple[Sampler, specs.Spec]:
  """Builds a sampler. See `Sampler` documentation."""

  if name not in specs.SPECS or name not in SAMPLERS:
    raise NotImplementedError(f'No implementation of algorithm {name}.')
  spec = specs.SPECS[name]
  algorithm = getattr(algorithms, name)
  sampler_class = SAMPLERS[name]
  # Ignore kwargs not accepted by the sampler.
  sampler_args = inspect.signature(sampler_class._sample_data).parameters  # pylint:disable=protected-access
  clean_kwargs = {k: kwargs[k] for k in kwargs if k in sampler_args}
  if set(clean_kwargs) != set(kwargs):
    logging.warning('Ignoring kwargs %s when building sampler class %s',
                    set(kwargs).difference(clean_kwargs), sampler_class)
  sampler = sampler_class(algorithm, spec, num_samples, seed=seed,
                          *args, **clean_kwargs)
  return sampler, spec
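
# Illustrative only: kwargs that the chosen sampler's `_sample_data` does not
# accept are dropped with a warning rather than raising. For example, `p` is
# not a parameter of `SortingSampler._sample_data`, so (hypothetically):
#
#   sampler, spec = build_sampler('insertion_sort', num_samples=10,
#                                 length=16, p=(0.5,), seed=0)
#   # logs: Ignoring kwargs {'p'} when building sampler class ...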


class SortingSampler(Sampler):
  """Sorting sampler. Generates a random sequence of U[0, 1]."""

  def _sample_data(
      self,
      length: int,
      low: float = 0.,
      high: float = 1.,
  ):
    arr = self._random_sequence(length=length, low=low, high=high)
    return [arr]


class SearchSampler(Sampler):
  """Search sampler. Generates a random sequence and target (of U[0, 1])."""

  def _sample_data(
      self,
      length: int,
      low: float = 0.,
      high: float = 1.,
  ):
    arr = self._random_sequence(length=length, low=low, high=high)
    arr.sort()
    x = self._rng.uniform(low=low, high=high)
    return [x, arr]


class MaxSubarraySampler(Sampler):
  """Maximum subarray sampler. Generates a random sequence of U[-1, 1]."""

  def _sample_data(
      self,
      length: int,
      low: float = -1.,
      high: float = 1.,
  ):
    arr = self._random_sequence(length=length, low=low, high=high)
    return [arr]


class LCSSampler(Sampler):
  """Longest Common Subsequence sampler. Generates two random ATCG strings."""

  def _sample_data(
      self,
      length: int,
      length_2: Optional[int] = None,
      chars: int = 4,
  ):
    if length_2 is None:
      # Assume provided length is total length.
      length_2 = length // 2
      length -= length_2
    a = self._random_string(length=length, chars=chars)
    b = self._random_string(length=length_2, chars=chars)
    return [a, b]


class OptimalBSTSampler(Sampler):
  """Optimal BST sampler. Samples array of probabilities, splits it into two."""

  def _sample_data(
      self,
      length: int,
  ):
    tot_length = length + (length + 1)
    arr = self._random_sequence(length=tot_length, low=0.0, high=1.0)
    arr /= np.sum(arr)
    p = arr[:length]
    q = arr[length:]
    return [p, q]
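
# Descriptive note: following the CLRS optimal-BST convention, `p` holds the
# `length` search probabilities of the keys and `q` the `length + 1` dummy
# (miss) probabilities; the joint normalization above makes them sum to 1.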


class ActivitySampler(Sampler):
  """Activity sampler. Samples start and finish times from U[0, 1]."""

  def _sample_data(
      self,
      length: int,
      low: float = 0.,
      high: float = 1.,
  ):
    arr_1 = self._random_sequence(length=length, low=low, high=high)
    arr_2 = self._random_sequence(length=length, low=low, high=high)
    return [np.minimum(arr_1, arr_2), np.maximum(arr_1, arr_2)]


class TaskSampler(Sampler):
  """Task sampler. Samples deadlines (integers) and values (U[0, 1])."""

  def _sample_data(
      self,
      length: int,
      max_deadline: Optional[int] = None,
      low: float = 0.,
      high: float = 1.,
  ):
    if max_deadline is None:
      max_deadline = length
    d = self._random_string(length=length, chars=max_deadline) + 1
    w = self._random_sequence(length=length, low=low, high=high)
    return [d, w]


class DfsSampler(Sampler):
  """DFS sampler."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.5,),
  ):
    graph = self._random_er_graph(
        nb_nodes=length, p=self._rng.choice(p),
        directed=True, acyclic=False, weighted=False)
    return [graph]


class BfsSampler(Sampler):
  """BFS sampler."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.5,),
  ):
    graph = self._random_er_graph(
        nb_nodes=length, p=self._rng.choice(p),
        directed=False, acyclic=False, weighted=False)
    source_node = self._rng.choice(length)
    return [graph, source_node]


class TopoSampler(Sampler):
  """Topological Sorting sampler."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.5,),
  ):
    graph = self._random_er_graph(
        nb_nodes=length, p=self._rng.choice(p),
        directed=True, acyclic=True, weighted=False)
    return [graph]


class ArticulationSampler(Sampler):
  """Articulation Point sampler."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.2,),
  ):
    graph = self._random_er_graph(
        nb_nodes=length, p=self._rng.choice(p), directed=False,
        acyclic=False, weighted=False)
    return [graph]


class MSTSampler(Sampler):
  """MST sampler for Kruskal's algorithm."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.2,),  # lower p to account for class imbalance
      low: float = 0.,
      high: float = 1.,
  ):
    graph = self._random_er_graph(
        nb_nodes=length,
        p=self._rng.choice(p),
        directed=False,
        acyclic=False,
        weighted=True,
        low=low,
        high=high)
    return [graph]


class BellmanFordSampler(Sampler):
  """Bellman-Ford sampler."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.5,),
      low: float = 0.,
      high: float = 1.,
  ):
    graph = self._random_er_graph(
        nb_nodes=length,
        p=self._rng.choice(p),
        directed=False,
        acyclic=False,
        weighted=True,
        low=low,
        high=high)
    source_node = self._rng.choice(length)
    return [graph, source_node]


class DAGPathSampler(Sampler):
  """Sampler for DAG shortest paths."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.5,),
      low: float = 0.,
      high: float = 1.,
  ):
    graph = self._random_er_graph(
        nb_nodes=length,
        p=self._rng.choice(p),
        directed=True,
        acyclic=True,
        weighted=True,
        low=low,
        high=high)
    source_node = self._rng.choice(length)
    return [graph, source_node]


class FloydWarshallSampler(Sampler):
  """Sampler for all-pairs shortest paths."""

  def _sample_data(
      self,
      length: int,
      p: Tuple[float, ...] = (0.5,),
      low: float = 0.,
      high: float = 1.,
  ):
    graph = self._random_er_graph(
        nb_nodes=length,
        p=self._rng.choice(p),
        directed=False,
        acyclic=False,
        weighted=True,
        low=low,
        high=high)
    return [graph]


class SccSampler(Sampler):
  """Sampler for strongly connected component (SCC) tasks."""

  def _sample_data(
      self,
      length: int,
      k: int = 4,
      p: Tuple[float, ...] = (0.5,),
      eps: float = 0.01,
  ):
    graph = self._random_community_graph(
        nb_nodes=length, k=k, p=self._rng.choice(p), eps=eps,
        directed=True, acyclic=False, weighted=False)
    return [graph]


class BipartiteSampler(Sampler):
  """Sampler for bipartite matching-based flow networks."""

  def _sample_data(
      self,
      length: int,
      length_2: Optional[int] = None,
      p: Tuple[float, ...] = (0.3,),
  ):
    if length_2 is None:
      # Assume provided length is total length.
      length_2 = length // 2
      length -= length_2
    graph = self._random_bipartite_graph(n=length, m=length_2,
                                         p=self._rng.choice(p))
    return [graph, length, length_2, 0, length + length_2 + 1]


class MatcherSampler(Sampler):
  """String matching sampler; embeds needle in a random haystack."""

  def _sample_data(
      self,
      length: int,  # length of haystack + needle, i.e., total number of nodes
      length_needle: Optional[int] = None,
      chars: int = 4,
  ):
    if length_needle is None:
      if length < 5:
        length_needle = 1
      else:
        length_needle = length // 5
    elif length_needle < 0:  # randomize needle length
      length_needle = self._rng.randint(1, high=1 - length_needle)
    length_haystack = length - length_needle
    needle = self._random_string(length=length_needle, chars=chars)
    haystack = self._random_string(length=length_haystack, chars=chars)
    embed_pos = self._rng.choice(length_haystack - length_needle)
    haystack[embed_pos:embed_pos + length_needle] = needle
    return [haystack, needle]
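
# Worked example (illustrative): with length=20 and the defaults above,
# length_needle = 4 and length_haystack = 16; the 4-char needle is copied
# into the haystack at a random offset, so every sample contains a match.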


class SegmentsSampler(Sampler):
  """Two-segment sampler of points from (U[0, 1], U[0, 1])."""

  def _sample_data(self, length: int, low: float = 0., high: float = 1.):
    del length  # There are exactly four endpoints.

    # Quick CCW check (ignoring collinearity) for rejection sampling.
    def ccw(x_a, y_a, x_b, y_b, x_c, y_c):
      return (y_c - y_a) * (x_b - x_a) > (y_b - y_a) * (x_c - x_a)

    def intersect(xs, ys):
      return ccw(xs[0], ys[0], xs[2], ys[2], xs[3], ys[3]) != ccw(
          xs[1], ys[1], xs[2], ys[2], xs[3], ys[3]) and ccw(
              xs[0], ys[0], xs[1], ys[1], xs[2], ys[2]) != ccw(
                  xs[0], ys[0], xs[1], ys[1], xs[3], ys[3])

    # Decide (with uniform probability) whether this sample should intersect.
    coin_flip = self._rng.binomial(1, 0.5)

    xs = self._random_sequence(length=4, low=low, high=high)
    ys = self._random_sequence(length=4, low=low, high=high)
    while intersect(xs, ys) != coin_flip:
      xs = self._random_sequence(length=4, low=low, high=high)
      ys = self._random_sequence(length=4, low=low, high=high)

    return [xs, ys]


class ConvexHullSampler(Sampler):
  """Convex hull sampler of points over a disk of radius r."""

  def _sample_data(self, length: int, origin_x: float = 0.,
                   origin_y: float = 0., radius: float = 2.):
    thetas = self._random_sequence(length=length, low=0.0, high=2.0 * np.pi)
    rs = radius * np.sqrt(
        self._random_sequence(length=length, low=0.0, high=1.0))

    xs = rs * np.cos(thetas) + origin_x
    ys = rs * np.sin(thetas) + origin_y

    return [xs, ys]


SAMPLERS = {
    'insertion_sort': SortingSampler,
    'bubble_sort': SortingSampler,
    'heapsort': SortingSampler,
    'quicksort': SortingSampler,
    'quickselect': SortingSampler,
    'minimum': SortingSampler,
    'binary_search': SearchSampler,
    'find_maximum_subarray': MaxSubarraySampler,
    'find_maximum_subarray_kadane': MaxSubarraySampler,
    'matrix_chain_order': SortingSampler,
    'lcs_length': LCSSampler,
    'optimal_bst': OptimalBSTSampler,
    'activity_selector': ActivitySampler,
    'task_scheduling': TaskSampler,
    'dfs': DfsSampler,
    'topological_sort': TopoSampler,
    'strongly_connected_components': SccSampler,
    'articulation_points': ArticulationSampler,
    'bridges': ArticulationSampler,
    'bfs': BfsSampler,
    'mst_kruskal': MSTSampler,
    'mst_prim': BellmanFordSampler,
    'bellman_ford': BellmanFordSampler,
    'dag_shortest_paths': DAGPathSampler,
    'dijkstra': BellmanFordSampler,
    'floyd_warshall': FloydWarshallSampler,
    'bipartite_matching': BipartiteSampler,
    'naive_string_matcher': MatcherSampler,
    'kmp_matcher': MatcherSampler,
    'segments_intersect': SegmentsSampler,
    'graham_scan': ConvexHullSampler,
    'jarvis_march': ConvexHullSampler,
}


def _batch_io(traj_io: Trajectories) -> Trajectory:
  """Batches a trajectory of input/output samples along the batch axis per probe.

  Args:
    traj_io: An i/o trajectory of `DataPoint`s indexed by sample then probe.

  Returns:
    A |num probes| list of `DataPoint`s with the batch axis stacked into `data`.
  """

  assert traj_io  # non-empty
  for sample_io in traj_io:
    for i, dp in enumerate(sample_io):
      assert dp.data.shape[0] == 1  # batching axis
      assert traj_io[0][i].name == dp.name

  return jax.tree_util.tree_map(lambda *x: np.concatenate(x), *traj_io)
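
# Shape note (descriptive): each i/o `DataPoint` arrives as [1, ...] (leading
# batch axis of size 1); concatenating N samples per probe yields [N, ...].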


def _batch_hints(
    traj_hints: Trajectories, min_steps: int) -> Tuple[Trajectory, List[int]]:
  """Batches a trajectory of hints samples along the time axis per probe.

  Unlike i/o, hints have a variable-length time dimension. Before batching,
  each trajectory is padded to the maximum trajectory length.

  Args:
    traj_hints: A hint trajectory of `DataPoint`s indexed by sample then probe.
    min_steps: Hints will be padded at least to this length; if any hint is
      longer than this, the greater length will be used.

  Returns:
    A |num probes| list of `DataPoint`s with the time axis stacked into `data`,
    and a |sample| list containing the length of each trajectory.
  """

  max_steps = min_steps
  assert traj_hints  # non-empty
  for sample_hint in traj_hints:
    for dp in sample_hint:
      assert dp.data.shape[1] == 1  # batching axis
      if dp.data.shape[0] > max_steps:
        max_steps = dp.data.shape[0]
  time_and_batch = (max_steps, len(traj_hints))

  # Create zero-filled space for the batched hints, then copy each hint
  # up to the corresponding length.
  batched_traj = jax.tree_util.tree_map(
      lambda x: np.zeros(time_and_batch + x.shape[2:]),
      traj_hints[0])
  hint_lengths = np.zeros(len(traj_hints))

  for sample_idx, cur_sample in enumerate(traj_hints):
    for i in range(len(cur_sample)):
      assert batched_traj[i].name == cur_sample[i].name
      cur_data = cur_sample[i].data
      cur_length = cur_data.shape[0]
      batched_traj[i].data[:cur_length, sample_idx:sample_idx+1] = cur_data
      if i > 0:
        assert hint_lengths[sample_idx] == cur_length
      else:
        hint_lengths[sample_idx] = cur_length

  return batched_traj, hint_lengths
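
# Worked example (illustrative): two hint trajectories with time dimensions
# 3 and 5 and min_steps=0 are zero-padded into a [5, 2, ...] array per probe,
# and hint_lengths is [3, 5].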


def _subsample_data(
    trajectory: Trajectory,
    idx: List[int],
    axis: int = 0,
) -> Trajectory:
  """New `Trajectory` where each `DataPoint`'s data is subsampled along axis."""
  sampled_traj = []
  for dp in trajectory:
    sampled_data = np.take(dp.data, idx, axis=axis)
    sampled_traj.append(
        probing.DataPoint(dp.name, dp.location, dp.type_, sampled_data))
  return sampled_traj


def _preprocess_permutations(probes, enforce_permutations):
  """Replace should-be permutations with proper permutation pointer + mask."""
  output = []
  for x in probes:
    if x.type_ != specs.Type.SHOULD_BE_PERMUTATION:
      output.append(x)
      continue
    assert x.location == specs.Location.NODE
    if enforce_permutations:
      new_x, mask = probing.predecessor_to_cyclic_predecessor_and_first(x.data)
      output.append(
          probing.DataPoint(
              name=x.name,
              location=x.location,
              type_=specs.Type.PERMUTATION_POINTER,
              data=new_x))
      output.append(
          probing.DataPoint(
              name=x.name + '_mask',
              location=x.location,
              type_=specs.Type.MASK_ONE,
              data=mask))
    else:
      output.append(probing.DataPoint(name=x.name, location=x.location,
                                      type_=specs.Type.POINTER, data=x.data))
  return output
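
# Descriptive note: when permutations are enforced, a SHOULD_BE_PERMUTATION
# probe (e.g. the predecessor output of a sorting algorithm) is replaced by a
# PERMUTATION_POINTER probe plus a MASK_ONE probe marking the first element;
# otherwise it is simply relabeled as a plain POINTER.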


def process_permutations(spec, sample_iterator, enforce_permutations):
  """Replace should-be permutations with proper permutation pointer + mask."""

  def _iterate():
    while True:
      feedback = next(sample_iterator)
      features = feedback.features
      inputs = _preprocess_permutations(features.inputs, enforce_permutations)
      hints = _preprocess_permutations(features.hints, enforce_permutations)
      outputs = _preprocess_permutations(feedback.outputs,
                                         enforce_permutations)
      features = features._replace(inputs=tuple(inputs),
                                   hints=tuple(hints))
      feedback = feedback._replace(features=features,
                                   outputs=outputs)
      yield feedback

  new_spec = {}
  for k in spec:
    if (spec[k][1] == specs.Location.NODE and
        spec[k][2] == specs.Type.SHOULD_BE_PERMUTATION):
      if enforce_permutations:
        new_spec[k] = (spec[k][0], spec[k][1], specs.Type.PERMUTATION_POINTER)
        new_spec[k + '_mask'] = (spec[k][0], spec[k][1], specs.Type.MASK_ONE)
      else:
        new_spec[k] = (spec[k][0], spec[k][1], specs.Type.POINTER)
    else:
      new_spec[k] = spec[k]

  return new_spec, _iterate()


def process_pred_as_input(spec, sample_iterator):
  """Move pred_h hint to pred input."""

  def _iterate():
    while True:
      feedback = next(sample_iterator)
      features = feedback.features
      pred_h = [h for h in features.hints if h.name == 'pred_h']
      if pred_h:
        assert len(pred_h) == 1
        pred_h = pred_h[0]
        hints = [h for h in features.hints if h.name != 'pred_h']
        for i in range(len(features.lengths)):
          assert np.sum(np.abs(pred_h.data[1:int(features.lengths[i]), i] -
                               pred_h.data[0, i])) == 0.0
        inputs = tuple(features.inputs) + (
            probing.DataPoint(name='pred', location=pred_h.location,
                              type_=pred_h.type_, data=pred_h.data[0]),)
        features = features._replace(inputs=tuple(inputs),
                                     hints=tuple(hints))
        feedback = feedback._replace(features=features)
      yield feedback

  new_spec = {}
  for k in spec:
    if k == 'pred_h':
      assert spec[k] == (specs.Stage.HINT, specs.Location.NODE,
                         specs.Type.POINTER)
      new_spec['pred'] = (specs.Stage.INPUT, specs.Location.NODE,
                          specs.Type.POINTER)
    else:
      new_spec[k] = spec[k]

  return new_spec, _iterate()
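
# Descriptive note: the assert inside `_iterate` checks that `pred_h` is
# constant across the valid time steps of each trajectory, which is what
# makes it safe to move its first time slice into the inputs as `pred`.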


def process_random_pos(sample_iterator, rng):
  """Randomize the `pos` input from a sampler.

  The `pos` input is, by default, a scalar uniformly spaced between 0 and 1
  across the nodes. The exceptions are string algorithms
  (naive_string_matcher, kmp_matcher and lcs_length), where the `pos` sequence
  is split into needle and haystack (or first and second string, for
  lcs_length). Here we replace the uniformly spaced `pos` with an ordered
  sequence of random scalars, or, for string algorithms, two ordered sequences
  of random scalars.

  Args:
    sample_iterator: An iterator producing samples with non-random `pos`
      inputs.
    rng: Numpy random generator.

  Returns:
    An iterator returning the samples with randomized `pos` inputs.
  """

  def _iterate():
    while True:
      feedback = next(sample_iterator)
      inputs = feedback.features.inputs
      pos, = [x for x in inputs if x.name == 'pos']
      batch_size, num_nodes = pos.data.shape
      unsorted = rng.uniform(size=(batch_size, num_nodes))

      new_pos = []
      for i in range(batch_size):  # we check one example at a time
        # Find the splits in the pos sequence, which are marked by zeros.
        # There is always at least one zero (at position 0), even when the
        # sequence has no split.
        split, = np.where(pos.data[i] == 0)
        split = np.concatenate([split, [num_nodes]])
        # Construct the randomized pos by sorting the random values in each
        # split and concatenating them.
        new_pos.append(
            np.concatenate([np.sort(unsorted[i, split[j]:split[j+1]])
                            for j in range(len(split) - 1)]))

      pos.data = np.array(new_pos)
      inputs = [(pos if x.name == 'pos' else x) for x in inputs]
      features = feedback.features._replace(inputs=inputs)
      feedback = feedback._replace(features=features)
      yield feedback

  return _iterate()
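
# Worked example (illustrative): a `pos` row [0, .25, .5, .75, 0, .5] has
# zeros at indices 0 and 4, so it is treated as two segments (0..3 and 4..5);
# each segment is replaced by an independently sorted run of U[0, 1] samples.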