# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Context for Universal Value Function agents.

A context specifies a list of contextual variables, each with its own
sampling and reward computation methods.

Examples of contextual variables include goal states, reward combination
vectors, etc.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import tensorflow as tf
from tf_agents import specs
import gin.tf
from utils import utils as uvf_utils


@gin.configurable
class Context(object):
  """Base context."""
  VAR_NAME = 'action'

  def __init__(self,
               tf_env,
               context_ranges=None,
               context_shapes=None,
               state_indices=None,
               variable_indices=None,
               gamma_index=None,
               settable_context=False,
               timers=None,
               samplers=None,
               reward_weights=None,
               reward_fn=None,
               random_sampler_mode='random',
               normalizers=None,
               context_transition_fn=None,
               context_multi_transition_fn=None,
               meta_action_every_n=None):
    self._tf_env = tf_env
    self.variable_indices = variable_indices
    self.gamma_index = gamma_index
    self._settable_context = settable_context
    self.timers = timers
    self._context_transition_fn = context_transition_fn
    self._context_multi_transition_fn = context_multi_transition_fn
    self._random_sampler_mode = random_sampler_mode

    # assign specs
    self._obs_spec = self._tf_env.observation_spec()
    self._context_shapes = tuple([
        shape if shape is not None else self._obs_spec.shape
        for shape in context_shapes
    ])
    self.context_specs = tuple([
        specs.TensorSpec(dtype=self._obs_spec.dtype, shape=shape)
        for shape in self._context_shapes
    ])
    if context_ranges is not None:
      self.context_ranges = context_ranges
    else:
      self.context_ranges = [None] * len(self._context_shapes)

    self.context_as_action_specs = tuple([
        specs.BoundedTensorSpec(
            shape=shape,
            dtype=(tf.float32 if self._obs_spec.dtype in
                   [tf.float32, tf.float64] else self._obs_spec.dtype),
            minimum=context_range[0],
            maximum=context_range[-1])
        for shape, context_range in zip(self._context_shapes,
                                        self.context_ranges)
    ])

    if state_indices is not None:
      self.state_indices = state_indices
    else:
      self.state_indices = [None] * len(self._context_shapes)
    if self.variable_indices is not None and self.n != len(
        self.variable_indices):
      raise ValueError(
          'variable_indices (%s) must have the same length as contexts (%s).' %
          (self.variable_indices, self.context_specs))
    assert self.n == len(self.context_ranges)
    assert self.n == len(self.state_indices)

    # assign reward/sampler fns
    self._sampler_fns = dict()
    self._samplers = dict()
    self._reward_fns = dict()

    # assign reward fns
    self._add_custom_reward_fns()
    reward_weights = reward_weights or None
    self._reward_fn = self._make_reward_fn(reward_fn, reward_weights)

    # assign samplers
    self._add_custom_sampler_fns()
    for mode, sampler_fns in samplers.items():
      self._make_sampler_fn(sampler_fns, mode)

    # create normalizers
    if normalizers is None:
      self._normalizers = [None] * len(self.context_specs)
    else:
      self._normalizers = [
          normalizer(tf.zeros(shape=spec.shape, dtype=spec.dtype))
          if normalizer is not None else None
          for normalizer, spec in zip(normalizers, self.context_specs)
      ]
    assert self.n == len(self._normalizers)

    self.meta_action_every_n = meta_action_every_n

    # create vars
    self.context_vars = {}
    self.timer_vars = {}
    self.create_vars(self.VAR_NAME)
    self.t = tf.Variable(
        tf.zeros(shape=(), dtype=tf.int32), name='num_timer_steps')
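
  # For example (illustrative; assumes a float32 observation spec), a single
  # 2-D goal context declared with
  #
  #   context_shapes=[(2,)], context_ranges=[((-4., -4.), (4., 4.))]
  #
  # produces one TensorSpec(shape=(2,), dtype=tf.float32) in `context_specs`
  # and a matching BoundedTensorSpec in `context_as_action_specs`, so the
  # context can double as a bounded meta-action for a higher-level policy.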

  def _add_custom_reward_fns(self):
    pass

  def _add_custom_sampler_fns(self):
    pass

  def sample_random_contexts(self, batch_size):
    """Sample a batch of random contexts."""
    assert self._random_sampler_mode is not None
    return self.sample_contexts(self._random_sampler_mode, batch_size)[0]

  def sample_contexts(self, mode, batch_size, state=None, next_state=None,
                      **kwargs):
    """Sample a batch of contexts.

    Args:
      mode: A string representing the mode [`train`, `explore`, `eval`].
      batch_size: Batch size.
      state: Optional [batch_size, num_state_dims] state tensor, forwarded to
        the samplers.
      next_state: Optional [batch_size, num_state_dims] next state tensor,
        forwarded to the samplers.
      **kwargs: Additional keyword arguments forwarded to the samplers.
    Returns:
      Two lists of [batch_size, num_context_dims] contexts.
    """
    contexts, next_contexts = self._sampler_fns[mode](
        batch_size, state=state, next_state=next_state, **kwargs)
    self._validate_contexts(contexts)
    self._validate_contexts(next_contexts)
    return contexts, next_contexts

  def compute_rewards(self, mode, states, actions, rewards, next_states,
                      contexts):
    """Compute context-based rewards.

    Args:
      mode: A string representing the mode ['uvf', 'task'].
      states: A [batch_size, num_state_dims] tensor.
      actions: A [batch_size, num_action_dims] tensor.
      rewards: A [batch_size] tensor representing unmodified rewards.
      next_states: A [batch_size, num_state_dims] tensor.
      contexts: A list of [batch_size, num_context_dims] tensors.
    Returns:
      Two [batch_size] tensors: context-based rewards and discounts.
    """
    return self._reward_fn(states, actions, rewards, next_states, contexts)
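
  # Example (sketch; the tensor arguments are placeholders, not names defined
  # in this module):
  #
  #   rewards, discounts = context.compute_rewards(
  #       'uvf', states, actions, env_rewards, next_states, context.vars)
  #
  # Each configured reward fn returns a (rewards, discounts) pair;
  # `_make_reward_fn` below combines them into a weighted sum of rewards and
  # a product of discounts.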
""" if not isinstance(reward_fns_list, (list, tuple)): reward_fns_list = [reward_fns_list] if reward_weights is None: reward_weights = [1.0] * len(reward_fns_list) assert len(reward_fns_list) == len(reward_weights) reward_fns_list = [ self._custom_reward_fns[fn] if isinstance(fn, (str,)) else fn for fn in reward_fns_list ] def reward_fn(*args, **kwargs): """Returns rewards, discounts.""" reward_tuples = [ reward_fn(*args, **kwargs) for reward_fn in reward_fns_list ] rewards_list = [reward_tuple[0] for reward_tuple in reward_tuples] discounts_list = [reward_tuple[1] for reward_tuple in reward_tuples] ndims = max([r.shape.ndims for r in rewards_list]) if ndims > 1: # expand reward shapes to allow broadcasting for i in range(len(rewards_list)): for _ in range(rewards_list[i].shape.ndims - ndims): rewards_list[i] = tf.expand_dims(rewards_list[i], axis=-1) for _ in range(discounts_list[i].shape.ndims - ndims): discounts_list[i] = tf.expand_dims(discounts_list[i], axis=-1) rewards = tf.add_n( [r * tf.to_float(w) for r, w in zip(rewards_list, reward_weights)]) discounts = discounts_list[0] for d in discounts_list[1:]: discounts *= d return rewards, discounts return reward_fn def _make_sampler_fn(self, sampler_cls_list, mode): """Returns a fn that samples a list of context vars. Args: sampler_cls_list: A list of sampler classes. mode: A string representing the operating mode. """ if not isinstance(sampler_cls_list, (list, tuple)): sampler_cls_list = [sampler_cls_list] self._samplers[mode] = [] sampler_fns = [] for spec, sampler in zip(self.context_specs, sampler_cls_list): if isinstance(sampler, (str,)): sampler_fn = self._custom_sampler_fns[sampler] else: sampler_fn = sampler(context_spec=spec) self._samplers[mode].append(sampler_fn) sampler_fns.append(sampler_fn) def batch_sampler_fn(batch_size, state=None, next_state=None, **kwargs): """Sampler fn.""" contexts_tuples = [ sampler(batch_size, state=state, next_state=next_state, **kwargs) for sampler in sampler_fns] contexts = [c[0] for c in contexts_tuples] next_contexts = [c[1] for c in contexts_tuples] contexts = [ normalizer.update_apply(c) if normalizer is not None else c for normalizer, c in zip(self._normalizers, contexts) ] next_contexts = [ normalizer.apply(c) if normalizer is not None else c for normalizer, c in zip(self._normalizers, next_contexts) ] return contexts, next_contexts self._sampler_fns[mode] = batch_sampler_fn def set_env_context_op(self, context, disable_unnormalizer=False): """Returns a TensorFlow op that sets the environment context. Args: context: A list of context Tensor variables. disable_unnormalizer: Disable unnormalization. Returns: A TensorFlow op that sets the environment context. """ ret_val = np.array(1.0, dtype=np.float32) if not self._settable_context: return tf.identity(ret_val) if not disable_unnormalizer: context = [ normalizer.unapply(tf.expand_dims(c, 0))[0] if normalizer is not None else c for normalizer, c in zip(self._normalizers, context) ] def set_context_func(*env_context_values): tf.logging.info('[set_env_context_op] Setting gym environment context.') # pylint: disable=protected-access self.gym_env.set_context(*env_context_values) return ret_val # pylint: enable=protected-access with tf.name_scope('set_env_context'): set_op = tf.py_func(set_context_func, context, tf.float32, name='set_env_context_py_func') set_op.set_shape([]) return set_op def set_replay(self, replay): """Set replay buffer for samplers. Args: replay: A replay buffer. 
""" for _, samplers in self._samplers.items(): for sampler in samplers: sampler.set_replay(replay) def get_clip_fns(self): """Returns a list of clip fns for contexts. Returns: A list of fns that clip context tensors. """ clip_fns = [] for context_range in self.context_ranges: def clip_fn(var_, range_=context_range): """Clip a tensor.""" if range_ is None: clipped_var = tf.identity(var_) elif isinstance(range_[0], (int, long, float, list, np.ndarray)): clipped_var = tf.clip_by_value( var_, range_[0], range_[1],) else: raise NotImplementedError(range_) return clipped_var clip_fns.append(clip_fn) return clip_fns def _validate_contexts(self, contexts): """Validate if contexts have right specs. Args: contexts: A list of [batch_size, num_context_dim] tensors. Raises: ValueError: If shape or dtype mismatches that of spec. """ for i, (context, spec) in enumerate(zip(contexts, self.context_specs)): if context[0].shape != spec.shape: raise ValueError('contexts[%d] has invalid shape %s wrt spec shape %s' % (i, context[0].shape, spec.shape)) if context.dtype != spec.dtype: raise ValueError('contexts[%d] has invalid dtype %s wrt spec dtype %s' % (i, context.dtype, spec.dtype)) def context_multi_transition_fn(self, contexts, **kwargs): """Returns multiple future contexts starting from a batch.""" assert self._context_multi_transition_fn return self._context_multi_transition_fn(contexts, None, None, **kwargs) def step(self, mode, agent=None, action_fn=None, **kwargs): """Returns [next_contexts..., next_timer] list of ops. Args: mode: a string representing the mode=[train, explore, eval]. **kwargs: kwargs for context_transition_fn. Returns: a list of ops that set the context. """ if agent is None: ops = [] if self._context_transition_fn is not None: def sampler_fn(): samples = self.sample_contexts(mode, 1)[0] return [s[0] for s in samples] values = self._context_transition_fn(self.vars, self.t, sampler_fn, **kwargs) ops += [tf.assign(var, value) for var, value in zip(self.vars, values)] ops.append(tf.assign_add(self.t, 1)) # increment timer return ops else: ops = agent.tf_context.step(mode, **kwargs) state = kwargs['state'] next_state = kwargs['next_state'] state_repr = kwargs['state_repr'] next_state_repr = kwargs['next_state_repr'] with tf.control_dependencies(ops): # Step high level context before computing low level one. # Get the context transition function output. values = self._context_transition_fn(self.vars, self.t, None, state=state_repr, next_state=next_state_repr) # Select a new goal every C steps, otherwise use context transition. low_level_context = [ tf.cond(tf.equal(self.t % self.meta_action_every_n, 0), lambda: tf.cast(action_fn(next_state, context=None), tf.float32), lambda: values)] ops = [tf.assign(var, value) for var, value in zip(self.vars, low_level_context)] with tf.control_dependencies(ops): return [tf.assign_add(self.t, 1)] # increment timer return ops def reset(self, mode, agent=None, action_fn=None, state=None): """Returns ops that reset the context. Args: mode: a string representing the mode=[train, explore, eval]. Returns: a list of ops that reset the context. 
""" if agent is None: values = self.sample_contexts(mode=mode, batch_size=1)[0] if values is None: return [] values = [value[0] for value in values] values[0] = uvf_utils.tf_print( values[0], values, message='context:reset, mode=%s' % mode, first_n=10, name='context:reset:%s' % mode) all_ops = [] for _, context_vars in sorted(self.context_vars.items()): ops = [tf.assign(var, value) for var, value in zip(context_vars, values)] all_ops += ops all_ops.append(self.set_env_context_op(values)) all_ops.append(tf.assign(self.t, 0)) # reset timer return all_ops else: ops = agent.tf_context.reset(mode) # NOTE: The code is currently written in such a way that the higher level # policy does not provide a low-level context until the second # observation. Insead, we just zero-out low-level contexts. for key, context_vars in sorted(self.context_vars.items()): ops += [tf.assign(var, tf.zeros_like(var)) for var, meta_var in zip(context_vars, agent.tf_context.context_vars[key])] ops.append(tf.assign(self.t, 0)) # reset timer return ops def create_vars(self, name, agent=None): """Create tf variables for contexts. Args: name: Name of the variables. Returns: A list of [num_context_dims] tensors. """ if agent is not None: meta_vars = agent.create_vars(name) else: meta_vars = {} assert name not in self.context_vars, ('Conflict! %s is already ' 'initialized.') % name self.context_vars[name] = tuple([ tf.Variable( tf.zeros(shape=spec.shape, dtype=spec.dtype), name='%s_context_%d' % (name, i)) for i, spec in enumerate(self.context_specs) ]) return self.context_vars[name], meta_vars @property def n(self): return len(self.context_specs) @property def vars(self): return self.context_vars[self.VAR_NAME] # pylint: disable=protected-access @property def gym_env(self): return self._tf_env.pyenv._gym_env @property def tf_env(self): return self._tf_env # pylint: enable=protected-access