|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""TensorFlow utility functions. |
|
""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
from copy import deepcopy |
|
import tensorflow as tf |
|
from tf_agents import specs |
|
from tf_agents.utils import common |
|
|
|
# Mutable module-level state backing `tf_print` (keyed by printer name/id):
# how many times each printer's py_func has fired.
_tf_print_counts = dict()
# running sums of each printed tensor, used when print_freq > 0.
_tf_print_running_sums = dict()
# number of executions accumulated since the last averaged log flush.
_tf_print_running_counts = dict()
# monotonically increasing id used to generate unique printer keys.
_tf_print_ids = 0
|
|
|
|
|
def get_contextual_env_base(env_base, begin_ops=None, end_ops=None):
  """Wrap env_base with additional tf ops.

  Dynamically creates a subclass of `env_base.__class__` whose episode
  boundaries additionally run the given tf ops via an attached session.

  Args:
    env_base: The environment instance to wrap.
    begin_ops: Optional tf ops run (via the session set by `set_sess`) at the
      beginning of each episode.
    end_ops: Optional tf ops run at the end of each episode.
  Returns:
    An instance of a dynamically created "ContextualEnvBase" subclass that
    wraps `env_base`.
  """

  def init(self_, env_base):
    # __init__ of the dynamic subclass: keep a handle to the wrapped env and
    # mirror selected attributes onto the wrapper.
    self_._env_base = env_base
    attribute_list = ["_render_mode", "_gym_env"]
    for attribute in attribute_list:
      if hasattr(env_base, attribute):
        setattr(self_, attribute, getattr(env_base, attribute))
    if hasattr(env_base, "physics"):
      self_._physics = env_base.physics
    elif hasattr(env_base, "gym"):
      # Gym envs expose no `physics`; synthesize one whose render() delegates
      # to the gym env's rgb_array rendering.
      class Physics(object):
        def render(self, *args, **kwargs):
          return env_base.gym.render("rgb_array")
      physics = Physics()
      self_._physics = physics
      self_.physics = physics

  def set_sess(self_, sess):
    # Attach the tf session used to run begin_ops/end_ops; forward it to the
    # wrapped env if it also supports set_sess.
    self_._sess = sess
    if hasattr(self_._env_base, "set_sess"):
      self_._env_base.set_sess(sess)

  def begin_episode(self_):
    # Reset the wrapped env, then run the episode-start ops (if any).
    self_._env_base.reset()
    if begin_ops is not None:
      self_._sess.run(begin_ops)

  def end_episode(self_):
    # Reset the wrapped env, then run the episode-end ops (if any).
    self_._env_base.reset()
    if end_ops is not None:
      self_._sess.run(end_ops)

  return type("ContextualEnvBase", (env_base.__class__,), dict(
      __init__=init,
      set_sess=set_sess,
      begin_episode=begin_episode,
      end_episode=end_episode,
  ))(env_base)
|
|
|
|
|
|
|
def merge_specs(specs_):
  """Merge TensorSpecs by concatenating along the first dimension.

  All specs must agree on dtype and on every dimension except the first.
  The result takes its name from the first spec.

  Args:
    specs_: List of TensorSpecs to be merged.
  Returns:
    a TensorSpec: a merged TensorSpec.
  """
  first = specs_[0]
  merged_shape = first.shape
  for other in specs_[1:]:
    assert merged_shape[1:] == other.shape[1:], "incompatible shapes: %s, %s" % (
        merged_shape, other.shape)
    assert first.dtype == other.dtype, "incompatible dtypes: %s, %s" % (
        first.dtype, other.dtype)
    merged_shape = merge_shapes((merged_shape, other.shape), axis=0)
  return specs.TensorSpec(
      shape=merged_shape,
      dtype=first.dtype,
      name=first.name,
  )
|
|
|
|
|
def merge_shapes(shapes, axis=0):
  """Merge TensorShapes by summing their sizes along one axis.

  Args:
    shapes: List of TensorShapes to be merged.
    axis: optional, the axis to merge shaped.
  Returns:
    a TensorShape: a merged TensorShape.
  """
  assert len(shapes) > 1
  base = shapes[0]
  # Copy so the first shape's dims are not mutated in place.
  merged_dims = deepcopy(base.dims)
  for other in shapes[1:]:
    assert base.ndims == other.ndims
    merged_dims[axis] += other.dims[axis]
  return tf.TensorShape(dims=merged_dims)
|
|
|
|
|
def get_all_vars(ignore_scopes=None):
  """Get all tf variables in scope.

  Args:
    ignore_scopes: A list of scope names to ignore.
  Returns:
    A list of all tf variables in scope.
  """
  candidates = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  # str.startswith accepts a tuple of prefixes; an empty tuple matches nothing,
  # so a None/empty ignore list keeps every variable.
  prefixes = tuple(ignore_scopes or ())
  return [var for var in candidates if not var.name.startswith(prefixes)]
|
|
|
|
|
def clip(tensor, range_=None):
  """Return a tf op which clips tensor according to range_.

  Args:
    tensor: A Tensor to be clipped.
    range_: None, or a tuple representing (minval, maxval)
  Returns:
    A clipped Tensor.
  """
  if range_ is None:
    # No bounds given: pass the tensor through unchanged.
    return tf.identity(tensor)
  if isinstance(range_, (tuple, list)):
    assert len(range_) == 2
    minval, maxval = range_
    return tf.clip_by_value(tensor, minval, maxval)
  raise NotImplementedError("Unacceptable range input: %r" % range_)
|
|
|
|
|
def clip_to_bounds(value, minimum, maximum):
  """Clips value to be between minimum and maximum.

  Args:
    value: (tensor) value to be clipped.
    minimum: (numpy float array) minimum value to clip to.
    maximum: (numpy float array) maximum value to clip to.
  Returns:
    clipped_value: (tensor) `value` clipped to between `minimum` and `maximum`.
  """
  # Apply the upper bound first, then the lower bound.
  return tf.maximum(tf.minimum(value, maximum), minimum)
|
|
|
|
|
# Public alias: use the canonical implementation from tf_agents' common module.
clip_to_spec = common.clip_to_spec
|
def _clip_to_spec(value, spec):
  """Clips value to a given bounded tensor spec.

  Args:
    value: (tensor) value to be clipped.
    spec: (BoundedTensorSpec) spec containing min. and max. values for clipping.
  Returns:
    clipped_value: (tensor) `value` clipped to be compatible with `spec`.
  """
  lower, upper = spec.minimum, spec.maximum
  return clip_to_bounds(value, lower, upper)
|
|
|
|
|
# Public alias: use the canonical implementation from tf_agents' common module.
join_scope = common.join_scope
|
def _join_scope(parent_scope, child_scope):
  """Joins a parent and child scope using `/`, checking for empty/none.

  Args:
    parent_scope: (string) parent/prefix scope.
    child_scope: (string) child/suffix scope.
  Returns:
    joined scope: (string) parent and child scopes joined by /.
  """
  if parent_scope and child_scope:
    return parent_scope + '/' + child_scope
  # One (or both) is empty/None: return whichever is non-empty.
  return parent_scope or child_scope
|
|
|
|
|
def assign_vars(vars_, values):
  """Returns the update ops for assigning a list of vars.

  Args:
    vars_: A list of variables.
    values: A list of tensors representing new values.
  Returns:
    A list of update ops for the variables.
  """
  update_ops = []
  for var, new_value in zip(vars_, values):
    update_ops.append(var.assign(new_value))
  return update_ops
|
|
|
|
|
def identity_vars(vars_):
  """Return the identity ops for a list of tensors.

  Args:
    vars_: A list of tensors.
  Returns:
    A list of identity ops.
  """
  identities = []
  for tensor in vars_:
    identities.append(tf.identity(tensor))
  return identities
|
|
|
|
|
def tile(var, batch_size=1):
  """Return tiled tensor.

  Args:
    var: A tensor representing the state.
    batch_size: Batch size.
  Returns:
    A tensor with shape [batch_size,] + var.shape.
  """
  rank = var.get_shape().ndims
  # Prepend a unit batch axis, then repeat batch_size times along it.
  expanded = tf.expand_dims(var, 0)
  multiples = (batch_size,) + (1,) * rank
  return tf.tile(expanded, multiples)
|
|
|
|
|
def batch_list(vars_list):
  """Batch a list of variables.

  Args:
    vars_list: A list of tensor variables.
  Returns:
    A list of tensor variables with additional first dimension.
  """
  batched = []
  for item in vars_list:
    batched.append(tf.expand_dims(item, axis=0))
  return batched
|
|
|
|
|
def tf_print(op,
             tensors,
             message="",
             first_n=-1,
             name=None,
             sub_messages=None,
             print_freq=-1,
             include_count=True):
  """tf.Print, but to stdout (via tf.logging.info).

  Attaches a py_func control dependency to `op` that logs the values of
  `tensors` each time `op` executes. When `print_freq > 0`, running averages
  over `print_freq` executions are logged instead of raw values.

  Args:
    op: The op to wrap; the print fires whenever the returned op runs.
    tensors: A list of tensors whose values are logged.
    message: (string) prefix for every log line.
    first_n: Only log for the first `first_n` executions; -1 means always.
    name: Optional key identifying this printer in the module-level counter
      dicts; auto-generated when None. (Bug fix: previously this argument was
      unconditionally overwritten and therefore silently ignored.)
    sub_messages: Optional list of per-tensor labels; defaults to the tensor
      index when None.
    print_freq: If > 0, log averaged values every `print_freq` executions.
    include_count: Whether to append the execution count to each log line.

  Returns:
    `op`, wrapped with a control dependency on the printing py_func.
  """

  global _tf_print_ids
  _tf_print_ids += 1
  if name is None:
    # Bug fix: the caller-supplied `name` used to be clobbered here, making
    # the parameter dead; now it is only auto-generated when not provided.
    name = _tf_print_ids
  _tf_print_counts[name] = 0
  if print_freq > 0:
    _tf_print_running_sums[name] = [0 for _ in tensors]
    _tf_print_running_counts[name] = 0

  def print_message(*xs):
    """Python-side logging fn executed through tf.py_func."""
    _tf_print_counts[name] += 1
    if print_freq > 0:
      # Accumulate values for averaged logging.
      for i, x in enumerate(xs):
        _tf_print_running_sums[name][i] += x
      _tf_print_running_counts[name] += 1
    if (print_freq <= 0 or _tf_print_running_counts[name] >= print_freq) and (
        first_n < 0 or _tf_print_counts[name] <= first_n):
      for i, x in enumerate(xs):
        if print_freq > 0:
          # Report the running average instead of the raw value.
          x = _tf_print_running_sums[name][i]/_tf_print_running_counts[name]
        if sub_messages is None:
          sub_message = str(i)
        else:
          sub_message = sub_messages[i]
        log_message = "%s, %s" % (message, sub_message)
        if include_count:
          log_message += ", count=%d" % _tf_print_counts[name]
        tf.logging.info("[%s]: %s" % (log_message, x))
      if print_freq > 0:
        # Reset the accumulators after each averaged flush.
        for i in range(len(xs)):
          _tf_print_running_sums[name][i] = 0
        _tf_print_running_counts[name] = 0
    # py_func must return something; pass the first tensor through.
    return xs[0]

  print_op = tf.py_func(print_message, tensors, tensors[0].dtype)
  with tf.control_dependencies([print_op]):
    op = tf.identity(op)
  return op
|
|
|
|
|
# Public alias: use the canonical implementation from tf_agents' common module.
periodically = common.periodically
|
def _periodically(body, period, name='periodically'):
  """Periodically performs a tensorflow op.

  Returns an op that runs `body` once every `period` executions; on other
  executions it merely increments an internal counter.
  """
  if period is None or period == 0:
    return tf.no_op()
  if period < 0:
    raise ValueError("period cannot be less than 0.")
  if period == 1:
    # Degenerate case: run the body on every execution.
    return body()

  with tf.variable_scope(None, default_name=name):
    # Counter starts at `period` so the body fires on the first execution.
    step = tf.get_variable(
        "counter",
        shape=[],
        dtype=tf.int64,
        trainable=False,
        initializer=tf.constant_initializer(period, dtype=tf.int64))

    def _run_and_reset():
      # Execute the body, then restart the count at 1.
      with tf.control_dependencies([body()]):
        return step.assign(1)

    return tf.cond(
        tf.equal(step, period), _run_and_reset,
        lambda: step.assign_add(1))
|
|
|
# Re-exported from tf_agents' common module for convenience.
soft_variables_update = common.soft_variables_update
|
|