|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Samplers for Contexts. |
|
|
|
Each sampler class should define __call__(batch_size). |
|
""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import numpy as np |
|
import tensorflow as tf |
|
slim = tf.contrib.slim |
|
import gin.tf |
|
|
|
|
|
@gin.configurable

class BaseSampler(object):

  """Base sampler for contexts; subclasses implement __call__(batch_size)."""



  def __init__(self, context_spec, context_range=None, k=2, scope='sampler'):

    """Construct a base sampler.



    Args:

      context_spec: A context spec (provides .shape and .dtype).

      context_range: A tuple of (minval, maxval), where minval, maxval are
        floats or Numpy arrays with the same shape as the context.

      k: Number of leading context dimensions given special treatment by
        subclasses (e.g. DirectionSampler slices the first k dimensions).

      scope: A string denoting scope.

    """

    self._context_spec = context_spec

    self._context_range = context_range

    self._k = k

    self._scope = scope



  def __call__(self, batch_size, **kwargs):

    # Subclasses must return a (contexts, next_contexts) pair of tensors.

    raise NotImplementedError



  def set_replay(self, replay=None):

    # No-op hook; overridden by samplers that draw from a replay buffer.

    pass



  def _validate_contexts(self, contexts):

    """Validate if contexts have right spec.



    Args:

      contexts: A [batch_size, num_contexts_dim] tensor.

    Raises:

      ValueError: If shape or dtype mismatches that of spec.

    """

    # Compare the per-example shape (contexts[0]) against the spec shape.

    if contexts[0].shape != self._context_spec.shape:

      raise ValueError('contexts has invalid shape %s wrt spec shape %s' %

                       (contexts[0].shape, self._context_spec.shape))

    if contexts.dtype != self._context_spec.dtype:

      raise ValueError('contexts has invalid dtype %s wrt spec dtype %s' %

                       (contexts.dtype, self._context_spec.dtype))
|
|
|
|
|
@gin.configurable
class ZeroSampler(BaseSampler):
  """Sampler that always returns all-zero contexts."""

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of zero contexts.

    Args:
      batch_size: Batch size.
    Returns:
      Two [batch_size, num_context_dims] tensors.
    """
    spec = self._context_spec
    batched_shape = [batch_size] + spec.shape.as_list()
    contexts = tf.zeros(shape=batched_shape, dtype=spec.dtype)
    # Contexts and next-contexts are identical for this sampler.
    return contexts, contexts
|
|
|
|
|
@gin.configurable
class BinarySampler(BaseSampler):
  """Sampler producing contexts with entries in {0, 1}."""

  def __init__(self, probs=0.5, *args, **kwargs):
    """Constructor.

    Args:
      probs: Threshold in [0, 1] compared against uniform noise.
      *args: arguments for BaseSampler.
      **kwargs: keyword arguments for BaseSampler.
    """
    super(BinarySampler, self).__init__(*args, **kwargs)
    self._probs = probs

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of binary contexts."""
    spec = self._context_spec
    batched_shape = [batch_size] + spec.shape.as_list()
    noise = tf.random_uniform(shape=batched_shape, dtype=tf.float32)
    # NOTE(review): an entry is 1 with probability (1 - probs), because it is
    # set when noise > probs -- confirm this matches the intended semantics
    # of the `probs` parameter.
    contexts = tf.cast(tf.greater(noise, self._probs), dtype=spec.dtype)
    return contexts, contexts
|
|
|
|
|
@gin.configurable
class RandomSampler(BaseSampler):
  """Sampler drawing contexts uniformly from the context range."""

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of context uniformly within self._context_range.

    Args:
      batch_size: Batch size.
      **kwargs: May contain 'state'/'next_state'; they are accepted for
        interface compatibility with other samplers but not used here.
        (Fixed: previously these were looked up unconditionally, raising
        KeyError for callers that did not pass them, and fed a dead branch.)
    Returns:
      Two [batch_size, num_context_dims] tensors.
    Raises:
      NotImplementedError: If the context range bound type is unsupported.
    """
    spec = self._context_spec
    context_range = self._context_range
    if isinstance(context_range[0], (int, float)):
      # Scalar bounds: a single uniform draw over the full context shape.
      contexts = tf.random_uniform(
          shape=[batch_size] + spec.shape.as_list(),
          minval=context_range[0],
          maxval=context_range[1],
          dtype=spec.dtype)
    elif isinstance(context_range[0], (list, tuple, np.ndarray)):
      # Per-dimension bounds: sample each context dimension independently
      # within its own [minval, maxval) interval, then concatenate.
      assert len(spec.shape.as_list()) == 1
      assert spec.shape.as_list()[0] == len(context_range[0])
      assert spec.shape.as_list()[0] == len(context_range[1])
      contexts = tf.concat(
          [
              tf.random_uniform(
                  shape=[batch_size, 1] + spec.shape.as_list()[1:],
                  minval=context_range[0][i],
                  maxval=context_range[1][i],
                  dtype=spec.dtype) for i in range(spec.shape.as_list()[0])
          ],
          axis=1)
    else:
      raise NotImplementedError(context_range)
    self._validate_contexts(contexts)
    return contexts, contexts
|
|
|
|
|
@gin.configurable

class ScheduledSampler(BaseSampler):

  """Scheduled sampler: cycles through a fixed list of context values."""



  def __init__(self,

               scope='default',

               values=None,

               scheduler='cycle',

               scheduler_params=None,

               *args, **kwargs):

    """Construct sampler.



    Args:

      scope: Scope name (used to name the internal pointer variable).

      values: A list of numbers or [num_context_dim] Numpy arrays

        representing the values to cycle.

      scheduler: scheduler type; only 'cycle' is implemented.

      scheduler_params: scheduler parameters, e.g. {'inc': step_size}.

      *args: arguments.

      **kwargs: keyword arguments.

    """

    super(ScheduledSampler, self).__init__(*args, **kwargs)

    self._scope = scope

    self._values = values

    self._scheduler = scheduler

    self._scheduler_params = scheduler_params or {}

    assert self._values is not None and len(

        self._values), 'must provide non-empty values.'

    self._n = len(self._values)



    # _count only tags the variable name; it is 0 at construction time, and
    # the increments performed in __call__ do not rename the variable.
    self._count = 0

    # Integer pointer into self._values, advanced by _next() on each call.
    self._i = tf.Variable(

        tf.zeros(shape=(), dtype=tf.int32),

        name='%s-scheduled_sampler_%d' % (self._scope, self._count))

    # Replaces the Python list with a tf.constant of the same values.
    self._values = tf.constant(self._values, dtype=self._context_spec.dtype)



  def __call__(self, batch_size, **kwargs):

    """Sample a batch of context.



    Args:

      batch_size: Batch size.

    Returns:

      Two [batch_size, num_context_dims] tensors.

    """

    spec = self._context_spec

    next_op = self._next(self._i)

    # NOTE(review): the value below is read under a control dependency on the
    # pointer-increment op; with TF1 ref variables the exact read-after-write
    # ordering is subtle -- confirm the intended pointer/value pairing.
    with tf.control_dependencies([next_op]):

      value = self._values[self._i]

      if value.get_shape().as_list():

        # Vector-valued entry: tile it across the batch dimension.
        values = tf.tile(

            tf.expand_dims(value, 0), (batch_size,) + (1,) * spec.shape.ndims)

      else:

        # Scalar entry: broadcast to the batched context shape via addition.
        values = value + tf.zeros(

            shape=[

                batch_size,

            ] + spec.shape.as_list(), dtype=spec.dtype)

    self._validate_contexts(values)

    self._count += 1

    return values, values



  def _next(self, i):

    """Return op that increments pointer to next value.



    Args:

      i: A tensorflow integer variable.

    Returns:

      Op that increments pointer.

    """

    if self._scheduler == 'cycle':

      # Step size; falls back to 1 when 'inc' is absent from
      # scheduler_params (or set to a falsy value such as 0).
      inc = ('inc' in self._scheduler_params and

             self._scheduler_params['inc']) or 1

      return tf.assign(i, tf.mod(i+inc, self._n))

    else:

      raise NotImplementedError(self._scheduler)
|
|
|
|
|
@gin.configurable
class ReplaySampler(BaseSampler):
  """Sampler that draws contexts from replayed next-states."""

  def __init__(self,
               prefetch_queue_capacity=2,
               override_indices=None,
               state_indices=None,
               *args,
               **kwargs):
    """Construct sampler.

    Args:
      prefetch_queue_capacity: Capacity for prefetch queue.
      override_indices: Indices of context dimensions to override with
        uniform samples from the scalar context range.
      state_indices: Select certain indices from state dimension.
      *args: arguments.
      **kwargs: keyword arguments.
    """
    super(ReplaySampler, self).__init__(*args, **kwargs)
    self._prefetch_queue_capacity = prefetch_queue_capacity
    self._override_indices = override_indices
    self._state_indices = state_indices

  def set_replay(self, replay):
    """Set the replay buffer to sample from.

    Args:
      replay: A replay buffer.
    """
    self._replay = replay

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of context from replayed next-states.

    Args:
      batch_size: Batch size.
    Returns:
      Two [batch_size, num_context_dims] tensors.
    """
    batch = self._replay.GetRandomBatch(batch_size)
    # Index 4 of the replay batch holds the next-state tensor
    # (assumed (s, a, r, d, s') layout -- TODO confirm against the buffer).
    next_states = batch[4]
    if self._prefetch_queue_capacity > 0:
      batch_queue = slim.prefetch_queue.prefetch_queue(
          [next_states],
          capacity=self._prefetch_queue_capacity,
          name='%s/batch_context_queue' % self._scope)
      next_states = batch_queue.dequeue()
    if self._override_indices is not None:
      # Replace the selected dimensions with uniform noise drawn from the
      # scalar context range.
      # Fixed: `long` does not exist on Python 3 and raised NameError here;
      # plain int/float covers both interpreters.
      assert self._context_range is not None and isinstance(
          self._context_range[0], (int, float))
      next_states = tf.concat(
          [
              tf.random_uniform(
                  shape=next_states[:, :1].shape,
                  minval=self._context_range[0],
                  maxval=self._context_range[1],
                  dtype=next_states.dtype)
              if i in self._override_indices else next_states[:, i:i + 1]
              for i in range(self._context_spec.shape.as_list()[0])
          ],
          axis=1)
    if self._state_indices is not None:
      # Fixed: select the configured state indices. The old loop iterated
      # range(num_context_dims) and ignored self._state_indices entirely,
      # always taking the first num_context_dims columns.
      next_states = tf.concat(
          [
              next_states[:, i:i + 1]
              for i in self._state_indices
          ],
          axis=1)
    self._validate_contexts(next_states)
    return next_states, next_states
|
|
|
|
|
@gin.configurable
class TimeSampler(BaseSampler):
  """Sampler producing integer time contexts."""

  def __init__(self, minval=0, maxval=1, timestep=-1, *args, **kwargs):
    """Construct sampler.

    Args:
      minval: Min value integer.
      maxval: Max value integer.
      timestep: Time step between states and next_states.
      *args: arguments.
      **kwargs: keyword arguments.
    """
    super(TimeSampler, self).__init__(*args, **kwargs)
    # Time contexts are scalar per example, so the spec must be [1].
    assert self._context_spec.shape.as_list() == [1]
    self._minval = minval
    self._maxval = maxval
    self._timestep = timestep

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of time contexts.

    Args:
      batch_size: Batch size.
    Returns:
      Two [batch_size, num_context_dims] tensors.
    """
    if self._minval == self._maxval:
      # Degenerate range: every context is the single allowed value.
      contexts = tf.constant(
          self._maxval, shape=[batch_size, 1], dtype=tf.int32)
    else:
      contexts = tf.random_uniform(
          shape=[batch_size, 1],
          dtype=tf.int32,
          maxval=self._maxval,
          minval=self._minval)
    # Advance by the (possibly negative) timestep, clamping at zero.
    next_contexts = tf.maximum(contexts + self._timestep, 0)
    out_dtype = self._context_spec.dtype
    return (tf.cast(contexts, dtype=out_dtype),
            tf.cast(next_contexts, dtype=out_dtype))
|
|
|
|
|
@gin.configurable
class ConstantSampler(BaseSampler):
  """Sampler that always returns the same constant context."""

  def __init__(self, value=None, *args, **kwargs):
    """Construct sampler.

    Args:
      value: A list or Numpy array for values of the constant.
      *args: arguments.
      **kwargs: keyword arguments.
    """
    super(ConstantSampler, self).__init__(*args, **kwargs)
    self._value = value

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of constant contexts.

    Args:
      batch_size: Batch size.
    Returns:
      Two [batch_size, num_context_dims] tensors.
    """
    spec = self._context_spec
    single = tf.constant(self._value, shape=spec.shape, dtype=spec.dtype)
    # Repeat the single constant context across the batch dimension.
    tiling = (batch_size,) + (1,) * spec.shape.ndims
    batched = tf.tile(tf.expand_dims(single, 0), tiling)
    self._validate_contexts(batched)
    return batched, batched
|
|
|
|
|
@gin.configurable
class DirectionSampler(RandomSampler):
  """Sampler that relabels the first k context dims as state directions."""

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of direction contexts.

    Args:
      batch_size: Batch size.
      **kwargs: Optionally 'sampler_fn' (a callable producing alternative
        contexts for the trailing dimensions), and 'state'/'next_state'
        tensors; when both states are given, the first k context dims are
        replaced by noisy directions to a shuffled batch of states.
    Returns:
      Two [batch_size, num_context_dims] tensors (gradients stopped).
    Raises:
      NotImplementedError: If the context range bound type is unsupported.
    """
    spec = self._context_spec
    context_range = self._context_range
    if isinstance(context_range[0], (int, float)):
      # Scalar bounds: a single uniform draw over the full context shape.
      contexts = tf.random_uniform(
          shape=[batch_size] + spec.shape.as_list(),
          minval=context_range[0],
          maxval=context_range[1],
          dtype=spec.dtype)
    elif isinstance(context_range[0], (list, tuple, np.ndarray)):
      # Per-dimension bounds: sample each context dimension independently.
      assert len(spec.shape.as_list()) == 1
      assert spec.shape.as_list()[0] == len(context_range[0])
      assert spec.shape.as_list()[0] == len(context_range[1])
      contexts = tf.concat(
          [
              tf.random_uniform(
                  shape=[batch_size, 1] + spec.shape.as_list()[1:],
                  minval=context_range[0][i],
                  maxval=context_range[1][i],
                  dtype=spec.dtype) for i in range(spec.shape.as_list()[0])
          ],
          axis=1)
    else:
      raise NotImplementedError(context_range)
    self._validate_contexts(contexts)
    if 'sampler_fn' in kwargs:
      other_contexts = kwargs['sampler_fn']()
    else:
      other_contexts = contexts
    # Fixed: use .get so callers that omit state/next_state fall through to
    # the identity branch instead of raising KeyError.
    state = kwargs.get('state')
    next_state = kwargs.get('next_state')
    if state is not None and next_state is not None:
      # Half-width of the context range per dimension, broadcast to the
      # context shape; scales the Gaussian perturbation below.
      my_context_range = (np.array(context_range[1]) -
                          np.array(context_range[0])) / 2 * np.ones(
                              spec.shape.as_list())
      # First k dims: direction from each state to a randomly shuffled state
      # in the batch, plus small Gaussian noise; trailing dims come from
      # other_contexts.
      contexts = tf.concat(
          [0.1 * my_context_range[:self._k] *
           tf.random_normal(tf.shape(state[:, :self._k]), dtype=state.dtype) +
           tf.random_shuffle(state[:, :self._k]) - state[:, :self._k],
           other_contexts[:, self._k:]], 1)
      # NOTE(review): the original code also built a relabeled next_contexts
      # (state + context - next_state on the first k dims) but immediately
      # overwrote it with `contexts`; that dead computation is removed here.
      # Confirm the overwrite (rather than the relabeling) was intended.
      next_contexts = contexts
    else:
      next_contexts = contexts
    return tf.stop_gradient(contexts), tf.stop_gradient(next_contexts)
|
|