# Copyright 2021 DeepMind Technologies Limited. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Sampling utilities."""
import abc
import collections
import inspect
import types
from typing import Any, Callable, List, Optional, Tuple
from absl import logging
from clrs._src import algorithms
from clrs._src import probing
from clrs._src import specs
import jax
import numpy as np
_Array = np.ndarray
_DataPoint = probing.DataPoint
Trajectory = List[_DataPoint]
Trajectories = List[Trajectory]
Algorithm = Callable[..., Any]
Features = collections.namedtuple('Features', ['inputs', 'hints', 'lengths'])
FeaturesChunked = collections.namedtuple(
    'FeaturesChunked', ['inputs', 'hints', 'is_first', 'is_last'])
Feedback = collections.namedtuple('Feedback', ['features', 'outputs'])
# CLRS-30 baseline spec.
CLRS30 = types.MappingProxyType({
'train': {
'num_samples': 1000,
'length': 16,
'seed': 1,
},
'val': {
'num_samples': 32,
'length': 16,
'seed': 2,
},
'test': {
'num_samples': 32,
'length': 64,
'seed': 3,
},
})
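# Example of reading the canonical split configuration (a usage sketch, not
# part of the original module):
#
#   CLRS30['train']['num_samples']  # -> 1000
#   CLRS30['test']['length']        # -> 64
#
# The proxy is read-only at the top level: `CLRS30['train'] = {}` raises a
# TypeError.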
class Sampler(abc.ABC):
"""Sampler abstract base class."""
def __init__(
self,
algorithm: Algorithm,
spec: specs.Spec,
num_samples: int,
*args,
seed: Optional[int] = None,
**kwargs,
):
"""Initializes a `Sampler`.
Args:
      algorithm: The algorithm to sample from.
spec: The algorithm spec.
num_samples: Number of algorithm unrolls to sample. If positive, all the
samples will be generated in the constructor, and at each call
of the `next` method a batch will be randomly selected among them.
If -1, samples are generated on the fly with each call to `next`.
*args: Algorithm args.
seed: RNG seed.
**kwargs: Algorithm kwargs.
"""
# Use `RandomState` to ensure deterministic sampling across Numpy versions.
self._rng = np.random.RandomState(seed)
self._spec = spec
self._num_samples = num_samples
self._algorithm = algorithm
self._args = args
self._kwargs = kwargs
if num_samples < 0:
logging.warning('Sampling dataset on-the-fly, unlimited samples.')
# Just get an initial estimate of max hint length
self.max_steps = -1
for _ in range(1000):
data = self._sample_data(*args, **kwargs)
_, probes = algorithm(*data)
_, _, hint = probing.split_stages(probes, spec)
for dp in hint:
assert dp.data.shape[1] == 1 # batching axis
if dp.data.shape[0] > self.max_steps:
self.max_steps = dp.data.shape[0]
else:
logging.info('Creating a dataset with %i samples.', num_samples)
(self._inputs, self._outputs, self._hints,
self._lengths) = self._make_batch(num_samples, spec, 0, algorithm, *args,
**kwargs)
def _make_batch(self, num_samples: int, spec: specs.Spec, min_length: int,
algorithm: Algorithm, *args, **kwargs):
"""Generate a batch of data."""
inputs = []
outputs = []
hints = []
for _ in range(num_samples):
data = self._sample_data(*args, **kwargs)
_, probes = algorithm(*data)
inp, outp, hint = probing.split_stages(probes, spec)
inputs.append(inp)
outputs.append(outp)
hints.append(hint)
if len(hints) % 1000 == 0:
logging.info('%i samples created', len(hints))
# Batch and pad trajectories to max(T).
inputs = _batch_io(inputs)
outputs = _batch_io(outputs)
hints, lengths = _batch_hints(hints, min_length)
return inputs, outputs, hints, lengths
def next(self, batch_size: Optional[int] = None) -> Feedback:
"""Subsamples trajectories from the pre-generated dataset.
Args:
batch_size: Optional batch size. If `None`, returns entire dataset.
Returns:
Subsampled trajectories.
"""
if batch_size:
if self._num_samples < 0: # generate on the fly
inputs, outputs, hints, lengths = self._make_batch(
batch_size, self._spec, self.max_steps,
self._algorithm, *self._args, **self._kwargs)
if hints[0].data.shape[0] > self.max_steps:
          logging.warning('Increasing hint length from %i to %i',
                          self.max_steps, hints[0].data.shape[0])
self.max_steps = hints[0].data.shape[0]
else:
if batch_size > self._num_samples:
raise ValueError(
f'Batch size {batch_size} > dataset size {self._num_samples}.')
# Returns a fixed-size random batch.
indices = self._rng.choice(self._num_samples, (batch_size,),
replace=True)
inputs = _subsample_data(self._inputs, indices, axis=0)
outputs = _subsample_data(self._outputs, indices, axis=0)
hints = _subsample_data(self._hints, indices, axis=1)
lengths = self._lengths[indices]
else:
# Returns the full dataset.
assert self._num_samples >= 0
inputs = self._inputs
hints = self._hints
lengths = self._lengths
outputs = self._outputs
return Feedback(Features(inputs, hints, lengths), outputs)
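  # Usage sketch (hypothetical values; `build_sampler` is defined below):
  #
  #   sampler, _ = build_sampler('bfs', num_samples=100, length=16, seed=0)
  #   batch = sampler.next(batch_size=32)  # random batch of 32 trajectories
  #   full = sampler.next()                # the entire 100-sample dataset
  #   # batch.features.{inputs, hints, lengths} hold batched DataPoints and
  #   # batch.outputs holds the ground-truth algorithm outputs.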
@abc.abstractmethod
def _sample_data(self, length: int, *args, **kwargs) -> List[_Array]:
pass
def _random_sequence(self, length, low=0.0, high=1.0):
"""Random sequence."""
return self._rng.uniform(low=low, high=high, size=(length,))
def _random_string(self, length, chars=4):
"""Random string."""
return self._rng.randint(0, high=chars, size=(length,))
def _random_er_graph(self, nb_nodes, p=0.5, directed=False, acyclic=False,
weighted=False, low=0.0, high=1.0):
"""Random Erdos-Renyi graph."""
mat = self._rng.binomial(1, p, size=(nb_nodes, nb_nodes))
if not directed:
mat *= np.transpose(mat)
elif acyclic:
mat = np.triu(mat, k=1)
p = self._rng.permutation(nb_nodes) # To allow nontrivial solutions
mat = mat[p, :][:, p]
if weighted:
weights = self._rng.uniform(low=low, high=high, size=(nb_nodes, nb_nodes))
if not directed:
weights *= np.transpose(weights)
      weights = np.sqrt(weights + 1e-3)  # Add epsilon to guard against underflow
mat = mat.astype(float) * weights
return mat
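  # Notes on the matrices produced above (descriptive, not normative):
  # - directed=False: the binomial mask is multiplied elementwise by its
  #   transpose, so an edge survives only if both directions were drawn;
  #   the result is symmetric with effective edge probability p**2.
  # - directed=True, acyclic=True: the mask is restricted to the strict upper
  #   triangle and then conjugated by a random permutation, yielding a DAG
  #   whose topological order is hidden.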
def _random_community_graph(self, nb_nodes, k=4, p=0.5, eps=0.01,
directed=False, acyclic=False, weighted=False,
low=0.0, high=1.0):
"""Random perturbed k-community graph."""
mat = np.zeros((nb_nodes, nb_nodes))
if k > nb_nodes:
      raise ValueError(
          f'Cannot generate a graph with more communities ({k}) '
          f'than nodes ({nb_nodes}).')
los, his = [], []
lo = 0
for i in range(k):
if i == k - 1:
hi = nb_nodes
else:
hi = lo + nb_nodes // k
mat[lo:hi, lo:hi] = self._random_er_graph(
hi - lo, p=p, directed=directed,
acyclic=acyclic, weighted=weighted,
low=low, high=high)
los.append(lo)
his.append(hi)
lo = hi
toggle = self._random_er_graph(nb_nodes, p=eps, directed=directed,
acyclic=acyclic, weighted=weighted,
low=low, high=high)
# Prohibit closing new cycles
for i in range(k):
for j in range(i):
toggle[los[i]:his[i], los[j]:his[j]] *= 0
mat = np.where(toggle > 0.0, (1.0 - (mat > 0.0)) * toggle, mat)
p = self._rng.permutation(nb_nodes) # To allow nontrivial solutions
mat = mat[p, :][:, p]
return mat
def _random_bipartite_graph(self, n, m, p=0.25):
"""Random bipartite graph-based flow network."""
nb_nodes = n + m + 2
s = 0
t = n + m + 1
mat = np.zeros((nb_nodes, nb_nodes))
mat[s, 1:n+1] = 1.0 # supersource
mat[n+1:n+m+1, t] = 1.0 # supersink
mat[1:n+1, n+1:n+m+1] = self._rng.binomial(1, p, size=(n, m))
return mat
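  # Resulting layout (read off the code above): node 0 is the supersource,
  # nodes 1..n are the left partition, nodes n+1..n+m the right partition,
  # and node n+m+1 is the supersink. Source->left and right->sink arcs have
  # unit capacity; left->right arcs are drawn i.i.d. with probability p.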
def build_sampler(
name: str,
num_samples: int,
*args,
seed: Optional[int] = None,
**kwargs,
) -> Tuple[Sampler, specs.Spec]:
"""Builds a sampler. See `Sampler` documentation."""
if name not in specs.SPECS or name not in SAMPLERS:
raise NotImplementedError(f'No implementation of algorithm {name}.')
spec = specs.SPECS[name]
algorithm = getattr(algorithms, name)
sampler_class = SAMPLERS[name]
# Ignore kwargs not accepted by the sampler.
sampler_args = inspect.signature(sampler_class._sample_data).parameters # pylint:disable=protected-access
clean_kwargs = {k: kwargs[k] for k in kwargs if k in sampler_args}
if set(clean_kwargs) != set(kwargs):
logging.warning('Ignoring kwargs %s when building sampler class %s',
set(kwargs).difference(clean_kwargs), sampler_class)
  sampler = sampler_class(algorithm, spec, num_samples, *args,
                          seed=seed, **clean_kwargs)
return sampler, spec
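# Example (a minimal sketch; uses algorithm names from `SAMPLERS` below):
#
#   sampler, spec = build_sampler('insertion_sort', num_samples=1000,
#                                 length=16, seed=1)
#   feedback = sampler.next(batch_size=32)
#
# Unrecognized kwargs are dropped with a warning, so one kwargs dict can be
# shared across algorithms whose samplers accept different parameters.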
class SortingSampler(Sampler):
"""Sorting sampler. Generates a random sequence of U[0, 1]."""
def _sample_data(
self,
length: int,
low: float = 0.,
high: float = 1.,
):
arr = self._random_sequence(length=length, low=low, high=high)
return [arr]
class SearchSampler(Sampler):
"""Search sampler. Generates a random sequence and target (of U[0, 1])."""
def _sample_data(
self,
length: int,
low: float = 0.,
high: float = 1.,
):
arr = self._random_sequence(length=length, low=low, high=high)
arr.sort()
x = self._rng.uniform(low=low, high=high)
return [x, arr]
class MaxSubarraySampler(Sampler):
"""Maximum subarray sampler. Generates a random sequence of U[-1, 1]."""
def _sample_data(
self,
length: int,
low: float = -1.,
high: float = 1.,
):
arr = self._random_sequence(length=length, low=low, high=high)
return [arr]
class LCSSampler(Sampler):
"""Longest Common Subsequence sampler. Generates two random ATCG strings."""
def _sample_data(
self,
length: int,
length_2: Optional[int] = None,
chars: int = 4,
):
if length_2 is None:
# Assume provided length is total length.
length_2 = length // 2
length -= length_2
a = self._random_string(length=length, chars=chars)
b = self._random_string(length=length_2, chars=chars)
return [a, b]
class OptimalBSTSampler(Sampler):
"""Optimal BST sampler. Samples array of probabilities, splits it into two."""
def _sample_data(
self,
length: int,
):
    # `length` key probabilities plus `length + 1` dummy ("gap") probabilities.
    tot_length = length + (length + 1)
arr = self._random_sequence(length=tot_length, low=0.0, high=1.0)
arr /= np.sum(arr)
p = arr[:length]
q = arr[length:]
return [p, q]
class ActivitySampler(Sampler):
"""Activity sampler. Samples start and finish times from U[0, 1]."""
def _sample_data(
self,
length: int,
low: float = 0.,
high: float = 1.,
):
arr_1 = self._random_sequence(length=length, low=low, high=high)
arr_2 = self._random_sequence(length=length, low=low, high=high)
return [np.minimum(arr_1, arr_2), np.maximum(arr_1, arr_2)]
class TaskSampler(Sampler):
"""Task sampler. Samples deadlines (integers) and values (U[0, 1])."""
def _sample_data(
self,
length: int,
max_deadline: Optional[int] = None,
low: float = 0.,
high: float = 1.,
):
if max_deadline is None:
max_deadline = length
d = self._random_string(length=length, chars=max_deadline) + 1
w = self._random_sequence(length=length, low=low, high=high)
return [d, w]
class DfsSampler(Sampler):
"""DFS sampler."""
def _sample_data(
self,
length: int,
p: Tuple[float, ...] = (0.5,),
):
graph = self._random_er_graph(
nb_nodes=length, p=self._rng.choice(p),
directed=True, acyclic=False, weighted=False)
return [graph]
class BfsSampler(Sampler):
"""BFS sampler."""
def _sample_data(
self,
length: int,
p: Tuple[float, ...] = (0.5,),
):
graph = self._random_er_graph(
nb_nodes=length, p=self._rng.choice(p),
directed=False, acyclic=False, weighted=False)
source_node = self._rng.choice(length)
return [graph, source_node]
class TopoSampler(Sampler):
"""Topological Sorting sampler."""
def _sample_data(
self,
length: int,
p: Tuple[float, ...] = (0.5,),
):
graph = self._random_er_graph(
nb_nodes=length, p=self._rng.choice(p),
directed=True, acyclic=True, weighted=False)
return [graph]
class ArticulationSampler(Sampler):
"""Articulation Point sampler."""
def _sample_data(
self,
length: int,
p: Tuple[float, ...] = (0.2,),
):
graph = self._random_er_graph(
nb_nodes=length, p=self._rng.choice(p), directed=False,
acyclic=False, weighted=False)
return [graph]
class MSTSampler(Sampler):
"""MST sampler for Kruskal's algorithm."""
def _sample_data(
self,
length: int,
p: Tuple[float, ...] = (0.2,), # lower p to account for class imbalance
low: float = 0.,
high: float = 1.,
):
graph = self._random_er_graph(
nb_nodes=length,
p=self._rng.choice(p),
directed=False,
acyclic=False,
weighted=True,
low=low,
high=high)
return [graph]
class BellmanFordSampler(Sampler):
"""Bellman-Ford sampler."""
def _sample_data(
self,
length: int,
p: Tuple[float, ...] = (0.5,),
low: float = 0.,
high: float = 1.,
):
graph = self._random_er_graph(
nb_nodes=length,
p=self._rng.choice(p),
directed=False,
acyclic=False,
weighted=True,
low=low,
high=high)
source_node = self._rng.choice(length)
return [graph, source_node]
class DAGPathSampler(Sampler):
"""Sampler for DAG shortest paths."""
def _sample_data(
self,
length: int,
p: Tuple[float, ...] = (0.5,),
low: float = 0.,
high: float = 1.,
):
graph = self._random_er_graph(
nb_nodes=length,
p=self._rng.choice(p),
directed=True,
acyclic=True,
weighted=True,
low=low,
high=high)
source_node = self._rng.choice(length)
return [graph, source_node]
class FloydWarshallSampler(Sampler):
"""Sampler for all-pairs shortest paths."""
def _sample_data(
self,
length: int,
p: Tuple[float, ...] = (0.5,),
low: float = 0.,
high: float = 1.,
):
graph = self._random_er_graph(
nb_nodes=length,
p=self._rng.choice(p),
directed=False,
acyclic=False,
weighted=True,
low=low,
high=high)
return [graph]
class SccSampler(Sampler):
"""Sampler for strongly connected component (SCC) tasks."""
def _sample_data(
self,
length: int,
k: int = 4,
p: Tuple[float, ...] = (0.5,),
eps: float = 0.01,
):
graph = self._random_community_graph(
nb_nodes=length, k=k, p=self._rng.choice(p), eps=eps,
directed=True, acyclic=False, weighted=False)
return [graph]
class BipartiteSampler(Sampler):
"""Sampler for bipartite matching-based flow networks."""
def _sample_data(
self,
length: int,
length_2: Optional[int] = None,
p: Tuple[float, ...] = (0.3,),
):
if length_2 is None:
# Assume provided length is total length.
length_2 = length // 2
length -= length_2
graph = self._random_bipartite_graph(n=length, m=length_2,
p=self._rng.choice(p))
return [graph, length, length_2, 0, length + length_2 + 1]
class MatcherSampler(Sampler):
"""String matching sampler; embeds needle in a random haystack."""
def _sample_data(
self,
length: int, # length of haystack + needle, i.e., total number of nodes
length_needle: Optional[int] = None,
chars: int = 4,
):
if length_needle is None:
if length < 5:
length_needle = 1
else:
length_needle = length // 5
elif length_needle < 0: # randomize needle length
length_needle = self._rng.randint(1, high=1 - length_needle)
length_haystack = length - length_needle
needle = self._random_string(length=length_needle, chars=chars)
haystack = self._random_string(length=length_haystack, chars=chars)
embed_pos = self._rng.choice(length_haystack - length_needle)
haystack[embed_pos:embed_pos + length_needle] = needle
return [haystack, needle]
class SegmentsSampler(Sampler):
"""Two-segment sampler of points from (U[0, 1], U[0, 1])."""
def _sample_data(self, length: int, low: float = 0., high: float = 1.):
del length # There are exactly four endpoints.
# Quick CCW check (ignoring collinearity) for rejection sampling
def ccw(x_a, y_a, x_b, y_b, x_c, y_c):
return (y_c - y_a) * (x_b - x_a) > (y_b - y_a) * (x_c - x_a)
def intersect(xs, ys):
return ccw(xs[0], ys[0], xs[2], ys[2], xs[3], ys[3]) != ccw(
xs[1], ys[1], xs[2], ys[2], xs[3], ys[3]) and ccw(
xs[0], ys[0], xs[1], ys[1], xs[2], ys[2]) != ccw(
xs[0], ys[0], xs[1], ys[1], xs[3], ys[3])
    # Decide (with probability 1/2) whether this sample should intersect.
coin_flip = self._rng.binomial(1, 0.5)
xs = self._random_sequence(length=4, low=low, high=high)
ys = self._random_sequence(length=4, low=low, high=high)
while intersect(xs, ys) != coin_flip:
xs = self._random_sequence(length=4, low=low, high=high)
ys = self._random_sequence(length=4, low=low, high=high)
return [xs, ys]
class ConvexHullSampler(Sampler):
"""Convex hull sampler of points over a disk of radius r."""
def _sample_data(self, length: int, origin_x: float = 0.,
origin_y: float = 0., radius: float = 2.):
thetas = self._random_sequence(length=length, low=0.0, high=2.0 * np.pi)
    # Taking the square root of a uniform sample makes the radial density
    # proportional to r, so points are uniform over the disk's area.
    rs = radius * np.sqrt(
        self._random_sequence(length=length, low=0.0, high=1.0))
xs = rs * np.cos(thetas) + origin_x
ys = rs * np.sin(thetas) + origin_y
return [xs, ys]
SAMPLERS = {
'insertion_sort': SortingSampler,
'bubble_sort': SortingSampler,
'heapsort': SortingSampler,
'quicksort': SortingSampler,
'quickselect': SortingSampler,
'minimum': SortingSampler,
'binary_search': SearchSampler,
'find_maximum_subarray': MaxSubarraySampler,
'find_maximum_subarray_kadane': MaxSubarraySampler,
'matrix_chain_order': SortingSampler,
'lcs_length': LCSSampler,
'optimal_bst': OptimalBSTSampler,
'activity_selector': ActivitySampler,
'task_scheduling': TaskSampler,
'dfs': DfsSampler,
'topological_sort': TopoSampler,
'strongly_connected_components': SccSampler,
'articulation_points': ArticulationSampler,
'bridges': ArticulationSampler,
'bfs': BfsSampler,
'mst_kruskal': MSTSampler,
'mst_prim': BellmanFordSampler,
'bellman_ford': BellmanFordSampler,
'dag_shortest_paths': DAGPathSampler,
'dijkstra': BellmanFordSampler,
'floyd_warshall': FloydWarshallSampler,
'bipartite_matching': BipartiteSampler,
'naive_string_matcher': MatcherSampler,
'kmp_matcher': MatcherSampler,
'segments_intersect': SegmentsSampler,
'graham_scan': ConvexHullSampler,
'jarvis_march': ConvexHullSampler,
}
def _batch_io(traj_io: Trajectories) -> Trajectory:
"""Batches a trajectory of input/output samples along the time axis per probe.
Args:
traj_io: An i/o trajectory of `DataPoint`s indexed by time then probe.
Returns:
A |num probes| list of `DataPoint`s with the time axis stacked into `data`.
"""
assert traj_io # non-empty
for sample_io in traj_io:
for i, dp in enumerate(sample_io):
assert dp.data.shape[0] == 1 # batching axis
assert traj_io[0][i].name == dp.name
return jax.tree_util.tree_map(lambda *x: np.concatenate(x), *traj_io)
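# Shape sketch (illustrative): if each of B samples carries a probe of shape
# [1, N, ...] (leading batch axis of size 1), the batched `DataPoint` holds
# data of shape [B, N, ...] after concatenation along axis 0.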
def _batch_hints(
traj_hints: Trajectories, min_steps: int) -> Tuple[Trajectory, List[int]]:
"""Batches a trajectory of hints samples along the time axis per probe.
Unlike i/o, hints have a variable-length time dimension. Before batching, each
trajectory is padded to the maximum trajectory length.
Args:
traj_hints: A hint trajectory of `DataPoints`s indexed by time then probe
min_steps: Hints will be padded at least to this length - if any hint is
longer than this, the greater length will be used.
Returns:
A |num probes| list of `DataPoint`s with the time axis stacked into `data`,
and a |sample| list containing the length of each trajectory.
"""
max_steps = min_steps
assert traj_hints # non-empty
for sample_hint in traj_hints:
for dp in sample_hint:
assert dp.data.shape[1] == 1 # batching axis
if dp.data.shape[0] > max_steps:
max_steps = dp.data.shape[0]
time_and_batch = (max_steps, len(traj_hints))
# Create zero-filled space for the batched hints, then copy each hint
# up to the corresponding length.
batched_traj = jax.tree_util.tree_map(
lambda x: np.zeros(time_and_batch + x.shape[2:]),
traj_hints[0])
hint_lengths = np.zeros(len(traj_hints))
for sample_idx, cur_sample in enumerate(traj_hints):
for i in range(len(cur_sample)):
assert batched_traj[i].name == cur_sample[i].name
cur_data = cur_sample[i].data
cur_length = cur_data.shape[0]
batched_traj[i].data[:cur_length, sample_idx:sample_idx+1] = cur_data
if i > 0:
assert hint_lengths[sample_idx] == cur_length
else:
hint_lengths[sample_idx] = cur_length
return batched_traj, hint_lengths
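# Shape sketch (illustrative): hint probes with data of shape [t_i, 1, N, ...]
# are zero-padded along time to T = max(min_steps, max_i t_i) and stacked into
# [T, B, N, ...]; `hint_lengths[i] == t_i` records each true length.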
def _subsample_data(
trajectory: Trajectory,
idx: List[int],
axis: int = 0,
) -> Trajectory:
"""New `Trajectory` where each `DataPoint`'s data is subsampled along axis."""
sampled_traj = []
for dp in trajectory:
sampled_data = np.take(dp.data, idx, axis=axis)
sampled_traj.append(
probing.DataPoint(dp.name, dp.location, dp.type_, sampled_data))
return sampled_traj
def _preprocess_permutations(probes, enforce_permutations):
"""Replace should-be permutations with proper permutation pointer + mask."""
output = []
for x in probes:
if x.type_ != specs.Type.SHOULD_BE_PERMUTATION:
output.append(x)
continue
assert x.location == specs.Location.NODE
if enforce_permutations:
new_x, mask = probing.predecessor_to_cyclic_predecessor_and_first(x.data)
output.append(
probing.DataPoint(
name=x.name,
location=x.location,
type_=specs.Type.PERMUTATION_POINTER,
data=new_x))
output.append(
probing.DataPoint(
name=x.name + '_mask',
location=x.location,
type_=specs.Type.MASK_ONE,
data=mask))
else:
output.append(probing.DataPoint(name=x.name, location=x.location,
type_=specs.Type.POINTER, data=x.data))
return output
def process_permutations(spec, sample_iterator, enforce_permutations):
"""Replace should-be permutations with proper permutation pointer + mask."""
def _iterate():
while True:
feedback = next(sample_iterator)
features = feedback.features
inputs = _preprocess_permutations(features.inputs, enforce_permutations)
hints = _preprocess_permutations(features.hints, enforce_permutations)
outputs = _preprocess_permutations(feedback.outputs, enforce_permutations)
features = features._replace(inputs=tuple(inputs),
hints=tuple(hints))
feedback = feedback._replace(features=features,
outputs=outputs)
yield feedback
new_spec = {}
for k in spec:
if (spec[k][1] == specs.Location.NODE and
spec[k][2] == specs.Type.SHOULD_BE_PERMUTATION):
if enforce_permutations:
new_spec[k] = (spec[k][0], spec[k][1], specs.Type.PERMUTATION_POINTER)
new_spec[k + '_mask'] = (spec[k][0], spec[k][1], specs.Type.MASK_ONE)
else:
new_spec[k] = (spec[k][0], spec[k][1], specs.Type.POINTER)
else:
new_spec[k] = spec[k]
return new_spec, _iterate()
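# Usage sketch (hypothetical; any iterator that yields `Feedback` will do):
#
#   new_spec, it = process_permutations(
#       spec, iter(lambda: sampler.next(32), None), enforce_permutations=True)
#   feedback = next(it)  # sorting outputs now carry PERMUTATION_POINTER + mask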
def process_pred_as_input(spec, sample_iterator):
"""Move pred_h hint to pred input."""
def _iterate():
while True:
feedback = next(sample_iterator)
features = feedback.features
pred_h = [h for h in features.hints if h.name == 'pred_h']
if pred_h:
assert len(pred_h) == 1
pred_h = pred_h[0]
hints = [h for h in features.hints if h.name != 'pred_h']
for i in range(len(features.lengths)):
assert np.sum(np.abs(pred_h.data[1:int(features.lengths[i]), i] -
pred_h.data[0, i])) == 0.0
inputs = tuple(features.inputs) + (
probing.DataPoint(name='pred', location=pred_h.location,
type_=pred_h.type_, data=pred_h.data[0]),)
features = features._replace(inputs=tuple(inputs),
hints=tuple(hints))
feedback = feedback._replace(features=features)
yield feedback
new_spec = {}
for k in spec:
if k == 'pred_h':
assert spec[k] == (specs.Stage.HINT, specs.Location.NODE,
specs.Type.POINTER)
new_spec['pred'] = (specs.Stage.INPUT, specs.Location.NODE,
specs.Type.POINTER)
else:
new_spec[k] = spec[k]
return new_spec, _iterate()
def process_random_pos(sample_iterator, rng):
"""Randomize the `pos` input from a sampler.
The `pos` input is, by default, a scalar uniformly spaced between 0 and 1
  across the nodes. The exceptions are the string algorithms
  (naive_string_matcher, kmp_matcher and lcs_length), where the `pos` sequence
  is split into needle and haystack (or first and second string, for
  lcs_length). Here we replace the uniformly spaced `pos` with an ordered
  sequence of random scalars, or, for string algorithms, two ordered sequences
  of random scalars.
Args:
sample_iterator: An iterator producing samples with non-random `pos` inputs.
    rng: Numpy random generator.
Returns:
An iterator returning the samples with randomized `pos` inputs.
"""
def _iterate():
while True:
feedback = next(sample_iterator)
inputs = feedback.features.inputs
pos, = [x for x in inputs if x.name == 'pos']
batch_size, num_nodes = pos.data.shape
unsorted = rng.uniform(size=(batch_size, num_nodes))
new_pos = []
      for i in range(batch_size):  # We check one example at a time.
        # Find the splits in the `pos` sequence, marked by zeros. `pos` always
        # starts at zero, so there is at least one zero even without a split.
split, = np.where(pos.data[i] == 0)
split = np.concatenate([split, [num_nodes]])
# We construct the randomized pos by sorting the random values in each
# split and concatenating them.
new_pos.append(
np.concatenate([np.sort(unsorted[i, split[j]:split[j+1]])
for j in range(len(split) - 1)]))
pos.data = np.array(new_pos)
inputs = [(pos if x.name == 'pos' else x) for x in inputs]
features = feedback.features._replace(inputs=inputs)
feedback = feedback._replace(features=features)
yield feedback
return _iterate()
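# The wrappers above compose; a hedged end-to-end sketch:
#
#   sampler, spec = build_sampler('insertion_sort', num_samples=-1, length=16)
#   it = iter(lambda: sampler.next(32), None)
#   spec, it = process_permutations(spec, it, enforce_permutations=True)
#   it = process_random_pos(it, np.random.RandomState(42))
#   feedback = next(it)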