# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Samplers for Contexts.

Each sampler class should define __call__(batch_size).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf

import gin.tf

slim = tf.contrib.slim


@gin.configurable
class BaseSampler(object):
  """Base sampler."""

  def __init__(self, context_spec, context_range=None, k=2, scope='sampler'):
    """Construct a base sampler.

    Args:
      context_spec: A context spec.
      context_range: A tuple of (minval, maxval), where minval and maxval are
        floats or Numpy arrays with the same shape as the context.
      k: Number of leading context dimensions given special treatment by
        subclasses (e.g. DirectionSampler).
      scope: A string denoting scope.
    """
    self._context_spec = context_spec
    self._context_range = context_range
    self._k = k
    self._scope = scope

  def __call__(self, batch_size, **kwargs):
    raise NotImplementedError

  def set_replay(self, replay=None):
    pass

  def _validate_contexts(self, contexts):
    """Validate that contexts match the context spec.

    Args:
      contexts: A [batch_size, num_contexts_dim] tensor.
    Raises:
      ValueError: If shape or dtype mismatches that of spec.
    """
    if contexts[0].shape != self._context_spec.shape:
      raise ValueError('contexts has invalid shape %s wrt spec shape %s' %
                       (contexts[0].shape, self._context_spec.shape))
    if contexts.dtype != self._context_spec.dtype:
      raise ValueError('contexts has invalid dtype %s wrt spec dtype %s' %
                       (contexts.dtype, self._context_spec.dtype))


@gin.configurable
class ZeroSampler(BaseSampler):
  """Zero sampler."""

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of contexts.

    Args:
      batch_size: Batch size.
    Returns:
      Two [batch_size, num_context_dims] tensors.
    """
    contexts = tf.zeros(
        dtype=self._context_spec.dtype,
        shape=[batch_size] + self._context_spec.shape.as_list())
    return contexts, contexts


@gin.configurable
class BinarySampler(BaseSampler):
  """Binary sampler."""

  def __init__(self, probs=0.5, *args, **kwargs):
    """Construct the sampler with `probs` the threshold for emitting a 1."""
    super(BinarySampler, self).__init__(*args, **kwargs)
    self._probs = probs

  def __call__(self, batch_size, **kwargs):
    """Sample a batch of contexts.

    Each entry is an independent draw that equals 1 with probability
    (1 - probs), since a uniform [0, 1) sample is compared against `probs`.
    """
    spec = self._context_spec
    contexts = tf.random_uniform(
        shape=[batch_size] + spec.shape.as_list(), dtype=tf.float32)
    contexts = tf.cast(tf.greater(contexts, self._probs), dtype=spec.dtype)
    return contexts, contexts
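

# Note on the calling convention used below: some samplers read
# kwargs['state'] and kwargs['next_state']. A caller with no transition at
# hand can pass explicit Nones, e.g. (illustrative sketch, not executed):
#
#   contexts, next_contexts = sampler(batch_size, state=None, next_state=None)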
""" spec = self._context_spec context_range = self._context_range if isinstance(context_range[0], (int, float)): contexts = tf.random_uniform( shape=[ batch_size, ] + spec.shape.as_list(), minval=context_range[0], maxval=context_range[1], dtype=spec.dtype) elif isinstance(context_range[0], (list, tuple, np.ndarray)): assert len(spec.shape.as_list()) == 1 assert spec.shape.as_list()[0] == len(context_range[0]) assert spec.shape.as_list()[0] == len(context_range[1]) contexts = tf.concat( [ tf.random_uniform( shape=[ batch_size, 1, ] + spec.shape.as_list()[1:], minval=context_range[0][i], maxval=context_range[1][i], dtype=spec.dtype) for i in range(spec.shape.as_list()[0]) ], axis=1) else: raise NotImplementedError(context_range) self._validate_contexts(contexts) state, next_state = kwargs['state'], kwargs['next_state'] if state is not None and next_state is not None: pass #contexts = tf.concat( # [tf.random_normal(tf.shape(state[:, :self._k]), dtype=tf.float64) + # tf.random_shuffle(state[:, :self._k]), # contexts[:, self._k:]], 1) return contexts, contexts @gin.configurable class ScheduledSampler(BaseSampler): """Scheduled sampler.""" def __init__(self, scope='default', values=None, scheduler='cycle', scheduler_params=None, *args, **kwargs): """Construct sampler. Args: scope: Scope name. values: A list of numbers or [num_context_dim] Numpy arrays representing the values to cycle. scheduler: scheduler type. scheduler_params: scheduler parameters. *args: arguments. **kwargs: keyword arguments. """ super(ScheduledSampler, self).__init__(*args, **kwargs) self._scope = scope self._values = values self._scheduler = scheduler self._scheduler_params = scheduler_params or {} assert self._values is not None and len( self._values), 'must provide non-empty values.' self._n = len(self._values) # TODO(shanegu): move variable creation outside. resolve tf.cond problem. self._count = 0 self._i = tf.Variable( tf.zeros(shape=(), dtype=tf.int32), name='%s-scheduled_sampler_%d' % (self._scope, self._count)) self._values = tf.constant(self._values, dtype=self._context_spec.dtype) def __call__(self, batch_size, **kwargs): """Sample a batch of context. Args: batch_size: Batch size. Returns: Two [batch_size, num_context_dims] tensors. """ spec = self._context_spec next_op = self._next(self._i) with tf.control_dependencies([next_op]): value = self._values[self._i] if value.get_shape().as_list(): values = tf.tile( tf.expand_dims(value, 0), (batch_size,) + (1,) * spec.shape.ndims) else: values = value + tf.zeros( shape=[ batch_size, ] + spec.shape.as_list(), dtype=spec.dtype) self._validate_contexts(values) self._count += 1 return values, values def _next(self, i): """Return op that increments pointer to next value. Args: i: A tensorflow integer variable. Returns: Op that increments pointer. """ if self._scheduler == 'cycle': inc = ('inc' in self._scheduler_params and self._scheduler_params['inc']) or 1 return tf.assign(i, tf.mod(i+inc, self._n)) else: raise NotImplementedError(self._scheduler) @gin.configurable class ReplaySampler(BaseSampler): """Replay sampler.""" def __init__(self, prefetch_queue_capacity=2, override_indices=None, state_indices=None, *args, **kwargs): """Construct sampler. Args: prefetch_queue_capacity: Capacity for prefetch queue. override_indices: Override indices. state_indices: Select certain indices from state dimension. *args: arguments. **kwargs: keyword arguments. 
""" super(ReplaySampler, self).__init__(*args, **kwargs) self._prefetch_queue_capacity = prefetch_queue_capacity self._override_indices = override_indices self._state_indices = state_indices def set_replay(self, replay): """Set replay. Args: replay: A replay buffer. """ self._replay = replay def __call__(self, batch_size, **kwargs): """Sample a batch of context. Args: batch_size: Batch size. Returns: Two [batch_size, num_context_dims] tensors. """ batch = self._replay.GetRandomBatch(batch_size) next_states = batch[4] if self._prefetch_queue_capacity > 0: batch_queue = slim.prefetch_queue.prefetch_queue( [next_states], capacity=self._prefetch_queue_capacity, name='%s/batch_context_queue' % self._scope) next_states = batch_queue.dequeue() if self._override_indices is not None: assert self._context_range is not None and isinstance( self._context_range[0], (int, long, float)) next_states = tf.concat( [ tf.random_uniform( shape=next_states[:, :1].shape, minval=self._context_range[0], maxval=self._context_range[1], dtype=next_states.dtype) if i in self._override_indices else next_states[:, i:i + 1] for i in range(self._context_spec.shape.as_list()[0]) ], axis=1) if self._state_indices is not None: next_states = tf.concat( [ next_states[:, i:i + 1] for i in range(self._context_spec.shape.as_list()[0]) ], axis=1) self._validate_contexts(next_states) return next_states, next_states @gin.configurable class TimeSampler(BaseSampler): """Time Sampler.""" def __init__(self, minval=0, maxval=1, timestep=-1, *args, **kwargs): """Construct sampler. Args: minval: Min value integer. maxval: Max value integer. timestep: Time step between states and next_states. *args: arguments. **kwargs: keyword arguments. """ super(TimeSampler, self).__init__(*args, **kwargs) assert self._context_spec.shape.as_list() == [1] self._minval = minval self._maxval = maxval self._timestep = timestep def __call__(self, batch_size, **kwargs): """Sample a batch of context. Args: batch_size: Batch size. Returns: Two [batch_size, num_context_dims] tensors. """ if self._maxval == self._minval: contexts = tf.constant( self._maxval, shape=[batch_size, 1], dtype=tf.int32) else: contexts = tf.random_uniform( shape=[batch_size, 1], dtype=tf.int32, maxval=self._maxval, minval=self._minval) next_contexts = tf.maximum(contexts + self._timestep, 0) return tf.cast( contexts, dtype=self._context_spec.dtype), tf.cast( next_contexts, dtype=self._context_spec.dtype) @gin.configurable class ConstantSampler(BaseSampler): """Constant sampler.""" def __init__(self, value=None, *args, **kwargs): """Construct sampler. Args: value: A list or Numpy array for values of the constant. *args: arguments. **kwargs: keyword arguments. """ super(ConstantSampler, self).__init__(*args, **kwargs) self._value = value def __call__(self, batch_size, **kwargs): """Sample a batch of context. Args: batch_size: Batch size. Returns: Two [batch_size, num_context_dims] tensors. """ spec = self._context_spec value_ = tf.constant(self._value, shape=spec.shape, dtype=spec.dtype) values = tf.tile( tf.expand_dims(value_, 0), (batch_size,) + (1,) * spec.shape.ndims) self._validate_contexts(values) return values, values @gin.configurable class DirectionSampler(RandomSampler): """Direction sampler.""" def __call__(self, batch_size, **kwargs): """Sample a batch of context. Args: batch_size: Batch size. Returns: Two [batch_size, num_context_dims] tensors. 
""" spec = self._context_spec context_range = self._context_range if isinstance(context_range[0], (int, float)): contexts = tf.random_uniform( shape=[ batch_size, ] + spec.shape.as_list(), minval=context_range[0], maxval=context_range[1], dtype=spec.dtype) elif isinstance(context_range[0], (list, tuple, np.ndarray)): assert len(spec.shape.as_list()) == 1 assert spec.shape.as_list()[0] == len(context_range[0]) assert spec.shape.as_list()[0] == len(context_range[1]) contexts = tf.concat( [ tf.random_uniform( shape=[ batch_size, 1, ] + spec.shape.as_list()[1:], minval=context_range[0][i], maxval=context_range[1][i], dtype=spec.dtype) for i in range(spec.shape.as_list()[0]) ], axis=1) else: raise NotImplementedError(context_range) self._validate_contexts(contexts) if 'sampler_fn' in kwargs: other_contexts = kwargs['sampler_fn']() else: other_contexts = contexts state, next_state = kwargs['state'], kwargs['next_state'] if state is not None and next_state is not None: my_context_range = (np.array(context_range[1]) - np.array(context_range[0])) / 2 * np.ones(spec.shape.as_list()) contexts = tf.concat( [0.1 * my_context_range[:self._k] * tf.random_normal(tf.shape(state[:, :self._k]), dtype=state.dtype) + tf.random_shuffle(state[:, :self._k]) - state[:, :self._k], other_contexts[:, self._k:]], 1) #contexts = tf.Print(contexts, # [contexts, tf.reduce_max(contexts, 0), # tf.reduce_min(state, 0), tf.reduce_max(state, 0)], 'contexts', summarize=15) next_contexts = tf.concat( #LALA [state[:, :self._k] + contexts[:, :self._k] - next_state[:, :self._k], other_contexts[:, self._k:]], 1) next_contexts = contexts #LALA cosine else: next_contexts = contexts return tf.stop_gradient(contexts), tf.stop_gradient(next_contexts)