|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""Context for Universal Value Function agents. |
|
|
|
A context specifies a list of contextual variables, each with its

own sampling and reward computation methods.
|
|
|
Examples of contextual variables include |
|
goal states, reward combination vectors, etc. |
|
""" |
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
import numpy as np |
|
import tensorflow as tf |
|
from tf_agents import specs |
|
import gin.tf |
|
from utils import utils as uvf_utils |
|
|
|
|
|
@gin.configurable
class Context(object):
  """Base context for Universal Value Function agents.

  A context is a tuple of contextual variables (e.g. goal states, reward
  combination vectors) together with:
    * `tf.TensorSpec`s describing each variable,
    * tf.Variables holding the current values (see `create_vars`),
    * per-mode sampler fns that draw new context values, and
    * a reward fn that scores transitions against the current context.
  """
  # Name under which the primary set of context variables is created.
  VAR_NAME = 'action'

  def __init__(self,
               tf_env,
               context_ranges=None,
               context_shapes=None,
               state_indices=None,
               variable_indices=None,
               gamma_index=None,
               settable_context=False,
               timers=None,
               samplers=None,
               reward_weights=None,
               reward_fn=None,
               random_sampler_mode='random',
               normalizers=None,
               context_transition_fn=None,
               context_multi_transition_fn=None,
               meta_action_every_n=None):
    """Constructs the context.

    Args:
      tf_env: A TF environment exposing `observation_spec()`.
      context_ranges: A list of (min, max) value ranges, one per context
        variable (None entries mean "unbounded" for clipping purposes).
      context_shapes: A list of shapes, one per context variable; a None
        entry falls back to the observation shape.
      state_indices: A list of state-index lists, one per context variable.
      variable_indices: Optional indices of the context variables.
      gamma_index: Optional index of the discount in the context.
      settable_context: If True, `set_env_context_op` pushes the context
        into the underlying gym environment.
      timers: Optional timer configuration.
      samplers: A dict mapping mode name -> sampler class (or list thereof).
      reward_weights: Optional per-reward-fn weights.
      reward_fn: A reward fn (or list of fns) or registered fn name(s).
      random_sampler_mode: Mode name used by `sample_random_contexts`.
      normalizers: Optional list of normalizer factories, one per context.
      context_transition_fn: Fn computing the next context values.
      context_multi_transition_fn: Fn computing multiple future contexts.
      meta_action_every_n: Period (in steps) of meta-action re-sampling.
    """
    self._tf_env = tf_env
    self.variable_indices = variable_indices
    self.gamma_index = gamma_index
    self._settable_context = settable_context
    self.timers = timers
    self._context_transition_fn = context_transition_fn
    self._context_multi_transition_fn = context_multi_transition_fn
    self._random_sampler_mode = random_sampler_mode

    # Any unspecified context shape defaults to the observation shape.
    self._obs_spec = self._tf_env.observation_spec()
    self._context_shapes = tuple([
        shape if shape is not None else self._obs_spec.shape
        for shape in context_shapes
    ])
    self.context_specs = tuple([
        specs.TensorSpec(dtype=self._obs_spec.dtype, shape=shape)
        for shape in self._context_shapes
    ])
    if context_ranges is not None:
      self.context_ranges = context_ranges
    else:
      self.context_ranges = [None] * len(self._context_shapes)

    # Bounded specs used when a context is emitted as a (meta-)action.
    # Float observations are widened/narrowed to float32 for the action.
    # NOTE(review): a None entry in context_ranges raises a TypeError here
    # (`context_range[0]`), so ranges are effectively required -- confirm
    # against callers.
    self.context_as_action_specs = tuple([
        specs.BoundedTensorSpec(
            shape=shape,
            dtype=(tf.float32 if self._obs_spec.dtype in
                   [tf.float32, tf.float64] else self._obs_spec.dtype),
            minimum=context_range[0],
            maximum=context_range[-1])
        for shape, context_range in zip(self._context_shapes, self.context_ranges)
    ])

    if state_indices is not None:
      self.state_indices = state_indices
    else:
      self.state_indices = [None] * len(self._context_shapes)
    if self.variable_indices is not None and self.n != len(
        self.variable_indices):
      raise ValueError(
          'variable_indices (%s) must have the same length as contexts (%s).' %
          (self.variable_indices, self.context_specs))
    assert self.n == len(self.context_ranges)
    assert self.n == len(self.state_indices)

    # Registries keyed by mode name ('train', 'explore', 'eval', ...).
    self._sampler_fns = dict()
    self._samplers = dict()
    self._reward_fns = dict()

    # Subclasses may register named reward fns before they are resolved.
    self._add_custom_reward_fns()
    # Treat falsy weights (None, [], 0) as "unweighted": _make_reward_fn
    # then defaults every weight to 1.0.
    reward_weights = reward_weights or None
    self._reward_fn = self._make_reward_fn(reward_fn, reward_weights)

    # Subclasses may register named sampler fns before they are resolved.
    self._add_custom_sampler_fns()
    for mode, sampler_fns in samplers.items():
      self._make_sampler_fn(sampler_fns, mode)

    # Optional per-context normalizers; each factory is primed with a
    # zero tensor of the context's shape/dtype.
    if normalizers is None:
      self._normalizers = [None] * len(self.context_specs)
    else:
      self._normalizers = [
          normalizer(tf.zeros(shape=spec.shape, dtype=spec.dtype))
          if normalizer is not None else None
          for normalizer, spec in zip(normalizers, self.context_specs)
      ]
    assert self.n == len(self._normalizers)

    self.meta_action_every_n = meta_action_every_n

    # Variables holding the current context values and timers, plus a
    # step counter `t` used for meta-action scheduling.
    self.context_vars = {}
    self.timer_vars = {}
    self.create_vars(self.VAR_NAME)
    self.t = tf.Variable(
        tf.zeros(shape=(), dtype=tf.int32), name='num_timer_steps')

  def _add_custom_reward_fns(self):
    """Hook for subclasses to populate `self._custom_reward_fns`."""
    pass

  def _add_custom_sampler_fns(self):
    """Hook for subclasses to populate `self._custom_sampler_fns`."""
    pass

  def sample_random_contexts(self, batch_size):
    """Sample a batch of contexts using the random sampler mode."""
    assert self._random_sampler_mode is not None
    return self.sample_contexts(self._random_sampler_mode, batch_size)[0]

  def sample_contexts(self, mode, batch_size, state=None, next_state=None,
                      **kwargs):
    """Sample a batch of contexts.

    Args:
      mode: A string representing the mode [`train`, `explore`, `eval`].
      batch_size: Batch size.
      state: Optional [batch_size, num_state_dims] state tensor passed to
        the sampler.
      next_state: Optional [batch_size, num_state_dims] next-state tensor
        passed to the sampler.
      **kwargs: Extra kwargs forwarded to the sampler fn.
    Returns:
      Two lists of [batch_size, num_context_dims] contexts
      (current and next), both validated against `context_specs`.
    """
    contexts, next_contexts = self._sampler_fns[mode](
        batch_size, state=state, next_state=next_state,
        **kwargs)
    self._validate_contexts(contexts)
    self._validate_contexts(next_contexts)
    return contexts, next_contexts

  def compute_rewards(self, mode, states, actions, rewards, next_states,
                      contexts):
    """Compute context-based rewards.

    Args:
      mode: A string representing the mode ['uvf', 'task'].
      states: A [batch_size, num_state_dims] tensor.
      actions: A [batch_size, num_action_dims] tensor.
      rewards: A [batch_size] tensor representing unmodified rewards.
      next_states: A [batch_size, num_state_dims] tensor.
      contexts: A list of [batch_size, num_context_dims] tensors.
    Returns:
      A (rewards, discounts) tuple of [batch_size] tensors.
    """
    return self._reward_fn(states, actions, rewards, next_states,
                           contexts)

  def _make_reward_fn(self, reward_fns_list, reward_weights):
    """Returns a fn that computes a weighted sum of rewards.

    Args:
      reward_fns_list: A fn (or list of fns) returning (rewards, discounts),
        or registered fn name(s) resolved via `self._custom_reward_fns`.
      reward_weights: A list of reward weights; defaults to all-ones.
    Returns:
      A fn mapping (*args, **kwargs) -> (rewards, discounts), where rewards
      are weight-summed across components and discounts are multiplied.
    """
    if not isinstance(reward_fns_list, (list, tuple)):
      reward_fns_list = [reward_fns_list]
    if reward_weights is None:
      reward_weights = [1.0] * len(reward_fns_list)
    assert len(reward_fns_list) == len(reward_weights)

    # Resolve registered fn names into callables.
    reward_fns_list = [
        self._custom_reward_fns[fn] if isinstance(fn, (str,)) else fn
        for fn in reward_fns_list
    ]

    def reward_fn(*args, **kwargs):
      """Returns rewards, discounts."""
      reward_tuples = [fn(*args, **kwargs) for fn in reward_fns_list]
      rewards_list = [reward_tuple[0] for reward_tuple in reward_tuples]
      discounts_list = [reward_tuple[1] for reward_tuple in reward_tuples]
      ndims = max([r.shape.ndims for r in rewards_list])
      if ndims > 1:
        # Expand lower-rank rewards/discounts with trailing singleton dims
        # so they broadcast against the max-rank component.
        # BUG FIX: the original computed `r.shape.ndims - ndims`, which is
        # never positive (ndims is the max), making these loops no-ops.
        for i in range(len(rewards_list)):
          for _ in range(ndims - rewards_list[i].shape.ndims):
            rewards_list[i] = tf.expand_dims(rewards_list[i], axis=-1)
          for _ in range(ndims - discounts_list[i].shape.ndims):
            discounts_list[i] = tf.expand_dims(discounts_list[i], axis=-1)
      rewards = tf.add_n(
          [r * tf.to_float(w) for r, w in zip(rewards_list, reward_weights)])
      # Discounts combine multiplicatively across components.
      discounts = discounts_list[0]
      for d in discounts_list[1:]:
        discounts *= d

      return rewards, discounts

    return reward_fn

  def _make_sampler_fn(self, sampler_cls_list, mode):
    """Registers a fn that samples a list of context vars for `mode`.

    Args:
      sampler_cls_list: A sampler class (or list of classes), or registered
        sampler name(s) resolved via `self._custom_sampler_fns`.
      mode: A string representing the operating mode.
    """
    if not isinstance(sampler_cls_list, (list, tuple)):
      sampler_cls_list = [sampler_cls_list]

    self._samplers[mode] = []
    sampler_fns = []
    for spec, sampler in zip(self.context_specs, sampler_cls_list):
      if isinstance(sampler, (str,)):
        sampler_fn = self._custom_sampler_fns[sampler]
      else:
        sampler_fn = sampler(context_spec=spec)
      self._samplers[mode].append(sampler_fn)
      sampler_fns.append(sampler_fn)

    def batch_sampler_fn(batch_size, state=None, next_state=None, **kwargs):
      """Samples all context vars, then (optionally) normalizes them."""
      contexts_tuples = [
          sampler(batch_size, state=state, next_state=next_state, **kwargs)
          for sampler in sampler_fns]
      contexts = [c[0] for c in contexts_tuples]
      next_contexts = [c[1] for c in contexts_tuples]
      # Current contexts update the normalizer statistics; next contexts
      # are normalized with the already-updated statistics.
      contexts = [
          normalizer.update_apply(c) if normalizer is not None else c
          for normalizer, c in zip(self._normalizers, contexts)
      ]
      next_contexts = [
          normalizer.apply(c) if normalizer is not None else c
          for normalizer, c in zip(self._normalizers, next_contexts)
      ]
      return contexts, next_contexts

    self._sampler_fns[mode] = batch_sampler_fn

  def set_env_context_op(self, context, disable_unnormalizer=False):
    """Returns a TensorFlow op that sets the environment context.

    Args:
      context: A list of context Tensor variables.
      disable_unnormalizer: If True, skip un-normalizing the context before
        handing it to the environment.
    Returns:
      A scalar float32 op that sets the environment context via py_func
      (or an identity no-op when the context is not settable).
    """
    ret_val = np.array(1.0, dtype=np.float32)
    if not self._settable_context:
      return tf.identity(ret_val)

    if not disable_unnormalizer:
      # Undo normalization so the env receives raw-scale values.
      context = [
          normalizer.unapply(tf.expand_dims(c, 0))[0]
          if normalizer is not None else c
          for normalizer, c in zip(self._normalizers, context)
      ]

    def set_context_func(*env_context_values):
      tf.logging.info('[set_env_context_op] Setting gym environment context.')
      self.gym_env.set_context(*env_context_values)
      return ret_val

    with tf.name_scope('set_env_context'):
      set_op = tf.py_func(set_context_func, context, tf.float32,
                          name='set_env_context_py_func')
      # py_func loses static shape info; restore the scalar shape.
      set_op.set_shape([])
    return set_op

  def set_replay(self, replay):
    """Sets the replay buffer on every registered sampler.

    Args:
      replay: A replay buffer.
    """
    for _, samplers in self._samplers.items():
      for sampler in samplers:
        sampler.set_replay(replay)

  def get_clip_fns(self):
    """Returns a list of clip fns for contexts.

    Returns:
      A list of fns (one per context) that clip context tensors to their
      configured range, or pass them through when the range is None.
    """
    clip_fns = []
    for context_range in self.context_ranges:
      def clip_fn(var_, range_=context_range):  # bind range_ per-iteration
        """Clip a tensor to `range_`."""
        if range_ is None:
          clipped_var = tf.identity(var_)
        # BUG FIX: the original also tested `long`, which is Python 2 only
        # and raised NameError under Python 3.
        elif isinstance(range_[0], (int, float, list, np.ndarray)):
          clipped_var = tf.clip_by_value(
              var_,
              range_[0],
              range_[1],)
        else: raise NotImplementedError(range_)
        return clipped_var
      clip_fns.append(clip_fn)
    return clip_fns

  def _validate_contexts(self, contexts):
    """Validate if contexts have right specs.

    Args:
      contexts: A list of [batch_size, num_context_dim] tensors.
    Raises:
      ValueError: If shape or dtype mismatches that of spec.
    """
    for i, (context, spec) in enumerate(zip(contexts, self.context_specs)):
      # Per-sample shape (context[0]) must match the spec shape.
      if context[0].shape != spec.shape:
        raise ValueError('contexts[%d] has invalid shape %s wrt spec shape %s' %
                         (i, context[0].shape, spec.shape))
      if context.dtype != spec.dtype:
        raise ValueError('contexts[%d] has invalid dtype %s wrt spec dtype %s' %
                         (i, context.dtype, spec.dtype))

  def context_multi_transition_fn(self, contexts, **kwargs):
    """Returns multiple future contexts starting from a batch."""
    assert self._context_multi_transition_fn
    return self._context_multi_transition_fn(contexts, None, None, **kwargs)

  def step(self, mode, agent=None, action_fn=None, **kwargs):
    """Returns [next_contexts..., next_timer] list of ops.

    Args:
      mode: a string representing the mode=[train, explore, eval].
      agent: Optional higher-level agent; when given, this context is
        treated as the low-level context driven by the agent's meta context.
      action_fn: Fn producing a meta action from (state, context); used
        every `meta_action_every_n` steps when `agent` is given.
      **kwargs: kwargs for context_transition_fn (must include `state`,
        `next_state`, `state_repr`, `next_state_repr` when `agent` is set).
    Returns:
      a list of ops that set the context.
    """
    if agent is None:
      ops = []
      if self._context_transition_fn is not None:
        def sampler_fn():
          # Draw a single fresh context (batch of 1, squeezed).
          samples = self.sample_contexts(mode, 1)[0]
          return [s[0] for s in samples]
        values = self._context_transition_fn(self.vars, self.t, sampler_fn, **kwargs)
        ops += [tf.assign(var, value) for var, value in zip(self.vars, values)]
      ops.append(tf.assign_add(self.t, 1))  # increment timer
      return ops
    else:
      ops = agent.tf_context.step(mode, **kwargs)
      state = kwargs['state']
      next_state = kwargs['next_state']
      state_repr = kwargs['state_repr']
      next_state_repr = kwargs['next_state_repr']
      with tf.control_dependencies(ops):  # Step high level context first.
        values = self._context_transition_fn(self.vars, self.t, None,
                                             state=state_repr,
                                             next_state=next_state_repr)
        # Re-sample the meta action every `meta_action_every_n` steps;
        # otherwise keep the transitioned context values.
        low_level_context = [
            tf.cond(tf.equal(self.t % self.meta_action_every_n, 0),
                    lambda: tf.cast(action_fn(next_state, context=None), tf.float32),
                    lambda: values)]
        ops = [tf.assign(var, value)
               for var, value in zip(self.vars, low_level_context)]
        with tf.control_dependencies(ops):
          return [tf.assign_add(self.t, 1)]
      return ops

  def reset(self, mode, agent=None, action_fn=None, state=None):
    """Returns ops that reset the context.

    Args:
      mode: a string representing the mode=[train, explore, eval].
      agent: Optional higher-level agent whose context is reset first.
      action_fn: Unused here; kept for interface symmetry with `step`.
      state: Unused here; kept for interface symmetry with `step`.
    Returns:
      a list of ops that reset the context.
    """
    if agent is None:
      values = self.sample_contexts(mode=mode, batch_size=1)[0]
      if values is None:
        return []
      values = [value[0] for value in values]
      # Log the sampled reset context (first few occurrences only).
      values[0] = uvf_utils.tf_print(
          values[0],
          values,
          message='context:reset, mode=%s' % mode,
          first_n=10,
          name='context:reset:%s' % mode)
      all_ops = []
      for _, context_vars in sorted(self.context_vars.items()):
        ops = [tf.assign(var, value) for var, value in zip(context_vars, values)]
        all_ops += ops
      all_ops.append(self.set_env_context_op(values))
      all_ops.append(tf.assign(self.t, 0))  # reset timer
      return all_ops
    else:
      ops = agent.tf_context.reset(mode)
      # NOTE(review): this zeroes the low-level context rather than copying
      # the meta agent's values; the original bound an unused `meta_var` in
      # the zip -- confirm zeroing (not copying) is the intended semantics.
      for key, context_vars in sorted(self.context_vars.items()):
        ops += [tf.assign(var, tf.zeros_like(var)) for var, _ in
                zip(context_vars, agent.tf_context.context_vars[key])]
      ops.append(tf.assign(self.t, 0))  # reset timer
      return ops

  def create_vars(self, name, agent=None):
    """Create tf variables for contexts.

    Args:
      name: Name of the variables.
      agent: Optional agent on which matching meta variables are created.
    Returns:
      A tuple of ([num_context_dims] variables, meta variables dict).
    """
    if agent is not None:
      meta_vars = agent.create_vars(name)
    else:
      meta_vars = {}
    assert name not in self.context_vars, ('Conflict! %s is already '
                                           'initialized.') % name
    self.context_vars[name] = tuple([
        tf.Variable(
            tf.zeros(shape=spec.shape, dtype=spec.dtype),
            name='%s_context_%d' % (name, i))
        for i, spec in enumerate(self.context_specs)
    ])
    return self.context_vars[name], meta_vars

  @property
  def n(self):
    """Number of context variables."""
    return len(self.context_specs)

  @property
  def vars(self):
    """The primary ('action') context variables."""
    return self.context_vars[self.VAR_NAME]

  @property
  def gym_env(self):
    """The underlying gym environment (reaches through the TF env wrapper)."""
    return self._tf_env.pyenv._gym_env

  @property
  def tf_env(self):
    """The wrapped TF environment."""
    return self._tf_env
|
|
|
|