# NCTCMumbai's picture
# Upload 2583 files
# 97b6013 verified
# Copyright 2018 The TensorFlow Authors All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TensorFlow utility functions.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
from copy import deepcopy
import tensorflow as tf
from tf_agents import specs
from tf_agents.utils import common
# Module-level bookkeeping used by tf_print below, keyed by a per-call
# integer id: total invocation counts, running sums (for averaged printing
# when print_freq > 0), and how many calls have accumulated since the last
# report. _tf_print_ids hands out a fresh id to each tf_print call.
_tf_print_counts = dict()
_tf_print_running_sums = dict()
_tf_print_running_counts = dict()
_tf_print_ids = 0
def get_contextual_env_base(env_base, begin_ops=None, end_ops=None):
  """Wrap env_base with additional tf ops.

  Args:
    env_base: Environment instance to wrap; selected attributes
      (physics/gym handles, render mode) are forwarded when present.
    begin_ops: Optional tf ops run via the stored session each time an
      episode begins.
    end_ops: Optional tf ops run via the stored session each time an
      episode ends.

  Returns:
    An instance of a dynamically created subclass of env_base's class whose
    begin_episode/end_episode hooks reset the wrapped env and run the given
    ops.
  """
  # pylint: disable=protected-access
  def init(self_, env_base):
    # Keep a handle on the wrapped env and mirror selected private
    # attributes so the wrapper is a drop-in replacement.
    self_._env_base = env_base
    attribute_list = ["_render_mode", "_gym_env"]
    for attribute in attribute_list:
      if hasattr(env_base, attribute):
        setattr(self_, attribute, getattr(env_base, attribute))
    if hasattr(env_base, "physics"):
      self_._physics = env_base.physics
    elif hasattr(env_base, "gym"):
      # Gym envs expose no `physics`; synthesize one whose render()
      # delegates to the gym env's rgb_array renderer.
      class Physics(object):
        def render(self, *args, **kwargs):
          return env_base.gym.render("rgb_array")
      physics = Physics()
      self_._physics = physics
      self_.physics = physics
  def set_sess(self_, sess):
    # Store the tf session used to run begin/end ops; propagate it to the
    # wrapped env when supported.
    self_._sess = sess
    if hasattr(self_._env_base, "set_sess"):
      self_._env_base.set_sess(sess)
  def begin_episode(self_):
    self_._env_base.reset()
    if begin_ops is not None:
      self_._sess.run(begin_ops)
  def end_episode(self_):
    self_._env_base.reset()
    if end_ops is not None:
      self_._sess.run(end_ops)
  # Dynamically subclass env_base's class so isinstance checks on the
  # wrapper still succeed.
  return type("ContextualEnvBase", (env_base.__class__,), dict(
      __init__=init,
      set_sess=set_sess,
      begin_episode=begin_episode,
      end_episode=end_episode,
  ))(env_base)
  # pylint: enable=protected-access
def merge_specs(specs_):
  """Merge TensorSpecs by concatenating along the leading dimension.

  All specs must share dtype and trailing dimensions; the first spec's name
  is retained on the result.

  Args:
    specs_: List of TensorSpecs to be merged.

  Returns:
    a TensorSpec: a merged TensorSpec.
  """
  first = specs_[0]
  merged_shape = first.shape
  for other in specs_[1:]:
    assert merged_shape[1:] == other.shape[1:], "incompatible shapes: %s, %s" % (
        merged_shape, other.shape)
    assert first.dtype == other.dtype, "incompatible dtypes: %s, %s" % (
        first.dtype, other.dtype)
    merged_shape = merge_shapes((merged_shape, other.shape), axis=0)
  return specs.TensorSpec(
      shape=merged_shape,
      dtype=first.dtype,
      name=first.name,
  )
def merge_shapes(shapes, axis=0):
  """Merge TensorShapes by summing one axis.

  Args:
    shapes: List of TensorShapes to be merged.
    axis: optional, the axis to merge shapes along.

  Returns:
    a TensorShape: the first shape with `axis` grown by the corresponding
    dimension of each remaining shape.
  """
  assert len(shapes) > 1
  base = shapes[0]
  merged_dims = deepcopy(base.dims)
  for other in shapes[1:]:
    assert base.ndims == other.ndims
    merged_dims[axis] += other.dims[axis]
  return tf.TensorShape(dims=merged_dims)
def get_all_vars(ignore_scopes=None):
  """Get all tf variables in scope.

  Args:
    ignore_scopes: A list of scope name prefixes to ignore, or None.

  Returns:
    A list of tf global variables whose names do not start with any of the
    ignored scope prefixes.
  """
  candidates = tf.get_collection(tf.GraphKeys.GLOBAL_VARIABLES)
  if ignore_scopes is None:
    return candidates
  return [
      var for var in candidates
      if not any(var.name.startswith(prefix) for prefix in ignore_scopes)
  ]
def clip(tensor, range_=None):
  """Return a tf op which clips tensor according to range_.

  Args:
    tensor: A Tensor to be clipped.
    range_: None, or a tuple representing (minval, maxval)

  Returns:
    A clipped Tensor.

  Raises:
    NotImplementedError: If range_ is neither None nor a tuple/list.
  """
  if range_ is None:
    return tf.identity(tensor)
  if isinstance(range_, (tuple, list)):
    assert len(range_) == 2
    minval, maxval = range_
    return tf.clip_by_value(tensor, minval, maxval)
  raise NotImplementedError("Unacceptable range input: %r" % range_)
def clip_to_bounds(value, minimum, maximum):
  """Clips value to be between minimum and maximum.

  Args:
    value: (tensor) value to be clipped.
    minimum: (numpy float array) minimum value to clip to.
    maximum: (numpy float array) maximum value to clip to.

  Returns:
    clipped_value: (tensor) `value` clipped to between `minimum` and `maximum`.
  """
  capped_above = tf.minimum(value, maximum)
  return tf.maximum(capped_above, minimum)
# Public alias: delegate to tf_agents' shared implementation.
# NOTE(review): the local _clip_to_spec below looks superseded by this
# alias and kept only for reference — confirm before removing either.
clip_to_spec = common.clip_to_spec
def _clip_to_spec(value, spec):
  """Clips value to a given bounded tensor spec.

  Args:
    value: (tensor) value to be clipped.
    spec: (BoundedTensorSpec) spec containing min. and max. values for clipping.

  Returns:
    clipped_value: (tensor) `value` clipped to be compatible with `spec`.
  """
  lower, upper = spec.minimum, spec.maximum
  return clip_to_bounds(value, lower, upper)
# Public alias: delegate to tf_agents' shared scope-joining helper.
# NOTE(review): the local _join_scope below looks superseded by this
# alias and kept only for reference — confirm before removing either.
join_scope = common.join_scope
def _join_scope(parent_scope, child_scope):
"""Joins a parent and child scope using `/`, checking for empty/none.
Args:
parent_scope: (string) parent/prefix scope.
child_scope: (string) child/suffix scope.
Returns:
joined scope: (string) parent and child scopes joined by /.
"""
if not parent_scope:
return child_scope
if not child_scope:
return parent_scope
return '/'.join([parent_scope, child_scope])
def assign_vars(vars_, values):
  """Returns the update ops for assigning a list of vars.

  Args:
    vars_: A list of variables.
    values: A list of tensors representing new values.

  Returns:
    A list of update ops for the variables, pairing vars_ and values
    positionally (extras in the longer list are ignored).
  """
  updates = []
  for variable, new_value in zip(vars_, values):
    updates.append(variable.assign(new_value))
  return updates
def identity_vars(vars_):
  """Return the identity ops for a list of tensors.

  Args:
    vars_: A list of tensors.

  Returns:
    A list of identity ops, one per input tensor.
  """
  return list(map(tf.identity, vars_))
def tile(var, batch_size=1):
  """Return tiled tensor.

  Args:
    var: A tensor representing the state.
    batch_size: Batch size.

  Returns:
    A tensor with shape [batch_size,] + var.shape.
  """
  # Repeat along a new leading axis only; all original axes are kept as-is.
  multiples = (batch_size,) + (1,) * var.get_shape().ndims
  expanded = tf.expand_dims(var, 0)
  return tf.tile(expanded, multiples)
def batch_list(vars_list):
  """Batch a list of variables.

  Args:
    vars_list: A list of tensor variables.

  Returns:
    A list of tensor variables with an additional leading dimension of 1.
  """
  return list(map(lambda tensor: tf.expand_dims(tensor, 0), vars_list))
def tf_print(op,
             tensors,
             message="",
             first_n=-1,
             name=None,
             sub_messages=None,
             print_freq=-1,
             include_count=True):
  """tf.Print, but to stdout.

  Attaches a py_func that logs `tensors` (or running averages of them when
  print_freq > 0) via tf.logging, and returns `op` gated on that side
  effect through a control dependency.

  Args:
    op: Op/tensor to return with the logging attached.
    tensors: List of tensors whose values are logged.
    message: Prefix string for each log line.
    first_n: If >= 0, stop logging after this many calls.
    name: Deprecated; overwritten with a fresh integer id below.
    sub_messages: Optional per-tensor labels; defaults to the tensor index.
    print_freq: If > 0, log averages once every print_freq calls instead of
      logging every call.
    include_count: If True, append the cumulative call count to each line.

  Returns:
    `op` wrapped in tf.identity with the print op as a control dependency.
  """
  # TODO(shanegu): `name` is deprecated. Remove from the rest of codes.
  global _tf_print_ids
  # Each tf_print call site gets a unique id that keys the module-level
  # bookkeeping dicts; the caller-supplied `name` is ignored.
  _tf_print_ids += 1
  name = _tf_print_ids
  _tf_print_counts[name] = 0
  if print_freq > 0:
    _tf_print_running_sums[name] = [0 for _ in tensors]
    _tf_print_running_counts[name] = 0
  def print_message(*xs):
    """print message fn."""
    _tf_print_counts[name] += 1
    if print_freq > 0:
      # Accumulate per-tensor sums so averages can be reported.
      for i, x in enumerate(xs):
        _tf_print_running_sums[name][i] += x
      _tf_print_running_counts[name] += 1
    if (print_freq <= 0 or _tf_print_running_counts[name] >= print_freq) and (
        first_n < 0 or _tf_print_counts[name] <= first_n):
      for i, x in enumerate(xs):
        if print_freq > 0:
          # Replace the raw value with the running average.
          del x
          x = _tf_print_running_sums[name][i]/_tf_print_running_counts[name]
        if sub_messages is None:
          sub_message = str(i)
        else:
          sub_message = sub_messages[i]
        log_message = "%s, %s" % (message, sub_message)
        if include_count:
          log_message += ", count=%d" % _tf_print_counts[name]
        tf.logging.info("[%s]: %s" % (log_message, x))
      if print_freq > 0:
        # Reset accumulators after each report.
        for i, x in enumerate(xs):
          _tf_print_running_sums[name][i] = 0
        _tf_print_running_counts[name] = 0
    # py_func must return a tensor-compatible value; pass through the first.
    return xs[0]
  print_op = tf.py_func(print_message, tensors, tensors[0].dtype)
  with tf.control_dependencies([print_op]):
    op = tf.identity(op)
  return op
# Public alias: delegate to tf_agents' shared periodic-op helper.
# NOTE(review): the local _periodically below looks superseded by this
# alias and kept only for reference — confirm before removing either.
periodically = common.periodically
def _periodically(body, period, name='periodically'):
  """Periodically performs a tensorflow op.

  Args:
    body: Zero-argument callable returning the tf op to run.
    period: Run body() once every `period` executions of the returned op.
      None or 0 disables execution entirely.
    name: Default name for the variable scope holding the counter.

  Returns:
    A tf op: a no_op when disabled, body() itself when period == 1, and
    otherwise a conditional update that runs body() every `period` steps.

  Raises:
    ValueError: If period is negative.
  """
  if period is None or period == 0:
    return tf.no_op()
  if period < 0:
    raise ValueError("period cannot be less than 0.")
  if period == 1:
    return body()
  with tf.variable_scope(None, default_name=name):
    # Counter starts at `period` so body() fires on the very first run.
    counter = tf.get_variable(
        "counter",
        shape=[],
        dtype=tf.int64,
        trainable=False,
        initializer=tf.constant_initializer(period, dtype=tf.int64))
    def _wrapped_body():
      # Run body(), then reset the counter to 1 (not 0), so the next fire
      # happens after period - 1 further increments.
      with tf.control_dependencies([body()]):
        return counter.assign(1)
    update = tf.cond(
        tf.equal(counter, period), _wrapped_body,
        lambda: counter.assign_add(1))
    return update
# Re-export tf_agents' soft variable update helper for convenience.
soft_variables_update = common.soft_variables_update