"""A DDPG/NAF agent. |
|
|
|
Implements the Deep Deterministic Policy Gradient (DDPG) algorithm from |
|
"Continuous control with deep reinforcement learning" - Lilicrap et al. |
|
https://arxiv.org/abs/1509.02971, and the Normalized Advantage Functions (NAF) |
|
algorithm "Continuous Deep Q-Learning with Model-based Acceleration" - Gu et al. |
|
https://arxiv.org/pdf/1603.00748. |
|
""" |
|
|
|
import tensorflow as tf |
|
slim = tf.contrib.slim |
|
import gin.tf |
|
from utils import utils |
|
from agents import ddpg_networks as networks |
|
|
|
|
|
@gin.configurable |
|
class DdpgAgent(object): |
|
"""An RL agent that learns using the DDPG algorithm. |
|
|
|
Example usage: |
|
|
|
def critic_net(states, actions): |
|
... |
|
  def actor_net(states, action_spec):
|
... |
|
|
|
Given a tensorflow environment tf_env, |
|
(of type learning.deepmind.rl.environments.tensorflow.python.tfpyenvironment) |
|
|
|
obs_spec = tf_env.observation_spec() |
|
action_spec = tf_env.action_spec() |
|
|
|
ddpg_agent = agent.DdpgAgent(obs_spec, |
|
action_spec, |
|
actor_net=actor_net, |
|
critic_net=critic_net) |
|
|
|
we can perform actions on the environment as follows: |
|
|
|
state = tf_env.observations()[0] |
|
action = ddpg_agent.actor_net(tf.expand_dims(state, 0))[0, :] |
|
transition_type, reward, discount = tf_env.step([action]) |
|
|
|
Train: |
|
|
|
critic_loss = ddpg_agent.critic_loss(states, actions, rewards, discounts, |
|
next_states) |
|
actor_loss = ddpg_agent.actor_loss(states) |
|
|
|
critic_train_op = slim.learning.create_train_op( |
|
critic_loss, |
|
critic_optimizer, |
|
variables_to_train=ddpg_agent.get_trainable_critic_vars(), |
|
) |
|
|
|
actor_train_op = slim.learning.create_train_op( |
|
actor_loss, |
|
actor_optimizer, |
|
variables_to_train=ddpg_agent.get_trainable_actor_vars(), |
|
) |
|
""" |
|
|
|
ACTOR_NET_SCOPE = 'actor_net' |
|
CRITIC_NET_SCOPE = 'critic_net' |
|
TARGET_ACTOR_NET_SCOPE = 'target_actor_net' |
|
TARGET_CRITIC_NET_SCOPE = 'target_critic_net' |
|
|
|
def __init__(self, |
|
observation_spec, |
|
action_spec, |
|
actor_net=networks.actor_net, |
|
critic_net=networks.critic_net, |
|
td_errors_loss=tf.losses.huber_loss, |
|
dqda_clipping=0., |
|
actions_regularizer=0., |
|
target_q_clipping=None, |
|
residual_phi=0.0, |
|
debug_summaries=False): |
|
"""Constructs a DDPG agent. |
|
|
|
Args: |
|
observation_spec: A TensorSpec defining the observations. |
|
action_spec: A BoundedTensorSpec defining the actions. |
|
      actor_net: A callable that creates the actor network. Must take the
        following arguments: states, action_spec. Please see
        networks.actor_net for an example.
|
critic_net: A callable that creates the critic network. Must take the |
|
following arguments: states, actions. Please see networks.critic_net |
|
for an example. |
|
td_errors_loss: A callable defining the loss function for the critic |
|
td error. |
|
dqda_clipping: (float) clips the gradient dqda element-wise between |
|
[-dqda_clipping, dqda_clipping]. Does not perform clipping if |
|
dqda_clipping == 0. |
|
actions_regularizer: A scalar, when positive penalizes the norm of the |
|
actions. This can prevent saturation of actions for the actor_loss. |
|
target_q_clipping: (tuple of floats) clips target q values within |
|
(low, high) values when computing the critic loss. |
|
residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that |
|
interpolates between Q-learning and residual gradient algorithm. |
|
http://www.leemon.com/papers/1995b.pdf |
|
debug_summaries: If True, add summaries to help debug behavior. |
|
Raises: |
|
ValueError: If 'dqda_clipping' is < 0. |
|
""" |
|
self._observation_spec = observation_spec[0] |
|
self._action_spec = action_spec[0] |
|
self._state_shape = tf.TensorShape([None]).concatenate( |
|
self._observation_spec.shape) |
|
self._action_shape = tf.TensorShape([None]).concatenate( |
|
self._action_spec.shape) |
|
self._num_action_dims = self._action_spec.shape.num_elements() |
|
|
|
self._scope = tf.get_variable_scope().name |
|
self._actor_net = tf.make_template( |
|
self.ACTOR_NET_SCOPE, actor_net, create_scope_now_=True) |
|
self._critic_net = tf.make_template( |
|
self.CRITIC_NET_SCOPE, critic_net, create_scope_now_=True) |
|
self._target_actor_net = tf.make_template( |
|
self.TARGET_ACTOR_NET_SCOPE, actor_net, create_scope_now_=True) |
|
self._target_critic_net = tf.make_template( |
|
self.TARGET_CRITIC_NET_SCOPE, critic_net, create_scope_now_=True) |
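    # tf.make_template creates each network's variables on the first call and
    # transparently reuses them on later calls, keeping the four networks in
    # disjoint, stable variable scopes under self._scope.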
|
self._td_errors_loss = td_errors_loss |
|
if dqda_clipping < 0: |
|
raise ValueError('dqda_clipping must be >= 0.') |
|
self._dqda_clipping = dqda_clipping |
|
self._actions_regularizer = actions_regularizer |
|
self._target_q_clipping = target_q_clipping |
|
self._residual_phi = residual_phi |
|
self._debug_summaries = debug_summaries |
|
|
|
def _batch_state(self, state): |
|
"""Convert state to a batched state. |
|
|
|
Args: |
|
      state: A state tensor [num_state_dims], or a list/tuple containing such
        a tensor as its first element.
    Returns:
      A [1, num_state_dims] tensor.
|
""" |
|
if isinstance(state, (tuple, list)): |
|
state = state[0] |
|
if state.get_shape().ndims == 1: |
|
state = tf.expand_dims(state, 0) |
|
return state |
|
|
|
def action(self, state): |
|
"""Returns the next action for the state. |
|
|
|
Args: |
|
state: A [num_state_dims] tensor representing a state. |
|
Returns: |
|
A [num_action_dims] tensor representing the action. |
|
""" |
|
return self.actor_net(self._batch_state(state), stop_gradients=True)[0, :] |
|
|
|
@gin.configurable('ddpg_sample_action') |
|
def sample_action(self, state, stddev=1.0): |
|
"""Returns the action for the state with additive noise. |
|
|
|
Args: |
|
state: A [num_state_dims] tensor representing a state. |
|
      stddev: Standard deviation of the Gaussian noise added to the action.
|
Returns: |
|
A [num_action_dims] action tensor. |
|
""" |
|
agent_action = self.action(state) |
|
agent_action += tf.random_normal(tf.shape(agent_action)) * stddev |
|
return utils.clip_to_spec(agent_action, self._action_spec) |
|
|
|
def actor_net(self, states, stop_gradients=False): |
|
"""Returns the output of the actor network. |
|
|
|
Args: |
|
states: A [batch_size, num_state_dims] tensor representing a batch |
|
of states. |
|
      stop_gradients: (boolean) if true, gradients cannot be propagated through
        this operation.
|
Returns: |
|
A [batch_size, num_action_dims] tensor of actions. |
|
Raises: |
|
ValueError: If `states` does not have the expected dimensions. |
|
""" |
|
self._validate_states(states) |
|
actions = self._actor_net(states, self._action_spec) |
|
if stop_gradients: |
|
actions = tf.stop_gradient(actions) |
|
return actions |
|
|
|
def critic_net(self, states, actions, for_critic_loss=False): |
|
"""Returns the output of the critic network. |
|
|
|
Args: |
|
states: A [batch_size, num_state_dims] tensor representing a batch |
|
of states. |
|
      actions: A [batch_size, num_action_dims] tensor representing a batch
        of actions.
      for_critic_loss: (boolean) whether the output will be used to compute
        the critic loss; forwarded to the underlying critic network.
|
Returns: |
|
q values: A [batch_size] tensor of q values. |
|
Raises: |
|
      ValueError: If `states` or `actions` do not have the expected dimensions.
|
""" |
|
self._validate_states(states) |
|
self._validate_actions(actions) |
|
return self._critic_net(states, actions, |
|
for_critic_loss=for_critic_loss) |
|
|
|
def target_actor_net(self, states): |
|
"""Returns the output of the target actor network. |
|
|
|
The target network is used to compute stable targets for training. |
|
|
|
Args: |
|
states: A [batch_size, num_state_dims] tensor representing a batch |
|
of states. |
|
Returns: |
|
A [batch_size, num_action_dims] tensor of actions. |
|
Raises: |
|
ValueError: If `states` does not have the expected dimensions. |
|
""" |
|
self._validate_states(states) |
|
actions = self._target_actor_net(states, self._action_spec) |
|
return tf.stop_gradient(actions) |
|
|
|
def target_critic_net(self, states, actions, for_critic_loss=False): |
|
"""Returns the output of the target critic network. |
|
|
|
The target network is used to compute stable targets for training. |
|
|
|
Args: |
|
states: A [batch_size, num_state_dims] tensor representing a batch |
|
of states. |
|
actions: A [batch_size, num_action_dims] tensor representing a batch |
|
of actions. |
|
Returns: |
|
q values: A [batch_size] tensor of q values. |
|
Raises: |
|
      ValueError: If `states` or `actions` do not have the expected dimensions.
|
""" |
|
self._validate_states(states) |
|
self._validate_actions(actions) |
|
return tf.stop_gradient( |
|
self._target_critic_net(states, actions, |
|
for_critic_loss=for_critic_loss)) |
|
|
|
def value_net(self, states, for_critic_loss=False): |
|
"""Returns the output of the critic evaluated with the actor. |
|
|
|
Args: |
|
states: A [batch_size, num_state_dims] tensor representing a batch |
|
of states. |
|
Returns: |
|
q values: A [batch_size] tensor of q values. |
|
""" |
|
actions = self.actor_net(states) |
|
return self.critic_net(states, actions, |
|
for_critic_loss=for_critic_loss) |
|
|
|
def target_value_net(self, states, for_critic_loss=False): |
|
"""Returns the output of the target critic evaluated with the target actor. |
|
|
|
Args: |
|
states: A [batch_size, num_state_dims] tensor representing a batch |
|
of states. |
|
Returns: |
|
q values: A [batch_size] tensor of q values. |
|
""" |
|
target_actions = self.target_actor_net(states) |
|
return self.target_critic_net(states, target_actions, |
|
for_critic_loss=for_critic_loss) |
|
|
|
def critic_loss(self, states, actions, rewards, discounts, |
|
next_states): |
|
"""Computes a loss for training the critic network. |
|
|
|
    The loss is given by td_errors_loss (Huber loss by default) between the
    critic's Q-value predictions and one-step TD targets computed from the
    target networks.
|
|
|
Args: |
|
states: A [batch_size, num_state_dims] tensor representing a batch |
|
of states. |
|
actions: A [batch_size, num_action_dims] tensor representing a batch |
|
of actions. |
|
rewards: A [batch_size, ...] tensor representing a batch of rewards, |
|
broadcastable to the critic net output. |
|
discounts: A [batch_size, ...] tensor representing a batch of discounts, |
|
broadcastable to the critic net output. |
|
next_states: A [batch_size, num_state_dims] tensor representing a batch |
|
of next states. |
|
Returns: |
|
A rank-0 tensor representing the critic loss. |
|
Raises: |
|
ValueError: If any of the inputs do not have the expected dimensions, or |
|
if their batch_sizes do not match. |
|
""" |
|
self._validate_states(states) |
|
self._validate_actions(actions) |
|
self._validate_states(next_states) |
|
|
|
target_q_values = self.target_value_net(next_states, for_critic_loss=True) |
|
td_targets = target_q_values * discounts + rewards |
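    # In symbols: y = r + gamma * Q'(s', pi'(s')), where Q' and pi' are the
    # target critic and target actor networks.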
|
if self._target_q_clipping is not None: |
|
td_targets = tf.clip_by_value(td_targets, self._target_q_clipping[0], |
|
self._target_q_clipping[1]) |
|
q_values = self.critic_net(states, actions, for_critic_loss=True) |
|
td_errors = td_targets - q_values |
|
if self._debug_summaries: |
|
gen_debug_td_error_summaries( |
|
target_q_values, q_values, td_targets, td_errors) |
|
|
|
loss = self._td_errors_loss(td_targets, q_values) |
|
|
|
if self._residual_phi > 0.0: |
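      # Residual-gradient interpolation (Baird, 1995; see the paper linked in
      # the constructor docstring): unlike the standard loss above, this term
      # lets gradients flow through the next-state value estimate, and
      # residual_phi = 1 recovers the pure residual-gradient update.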
|
residual_q_values = self.value_net(next_states, for_critic_loss=True) |
|
residual_td_targets = residual_q_values * discounts + rewards |
|
if self._target_q_clipping is not None: |
|
residual_td_targets = tf.clip_by_value(residual_td_targets, |
|
self._target_q_clipping[0], |
|
self._target_q_clipping[1]) |
|
residual_td_errors = residual_td_targets - q_values |
|
      residual_loss = self._td_errors_loss(
          residual_td_targets, tf.stop_gradient(q_values))
|
loss = (loss * (1.0 - self._residual_phi) + |
|
residual_loss * self._residual_phi) |
|
return loss |
|
|
|
def actor_loss(self, states): |
|
"""Computes a loss for training the actor network. |
|
|
|
Note that output does not represent an actual loss. It is called a loss only |
|
in the sense that its gradient w.r.t. the actor network weights is the |
|
correct gradient for training the actor network, |
|
i.e. dloss/dweights = (dq/da)*(da/dweights) |
|
    which is the gradient used in Algorithm 1 of Lillicrap et al.
|
|
|
Args: |
|
states: A [batch_size, num_state_dims] tensor representing a batch |
|
of states. |
|
Returns: |
|
A rank-0 tensor representing the actor loss. |
|
Raises: |
|
ValueError: If `states` does not have the expected dimensions. |
|
""" |
|
self._validate_states(states) |
|
actions = self.actor_net(states, stop_gradients=False) |
|
critic_values = self.critic_net(states, actions) |
|
q_values = self.critic_function(critic_values, states) |
|
dqda = tf.gradients([q_values], [actions])[0] |
|
dqda_unclipped = dqda |
|
if self._dqda_clipping > 0: |
|
dqda = tf.clip_by_value(dqda, -self._dqda_clipping, self._dqda_clipping) |
|
|
|
actions_norm = tf.norm(actions) |
|
if self._debug_summaries: |
|
with tf.name_scope('dqda'): |
|
tf.summary.scalar('actions_norm', actions_norm) |
|
tf.summary.histogram('dqda', dqda) |
|
tf.summary.histogram('dqda_unclipped', dqda_unclipped) |
|
tf.summary.histogram('actions', actions) |
|
for a in range(self._num_action_dims): |
|
tf.summary.histogram('dqda_unclipped_%d' % a, dqda_unclipped[:, a]) |
|
tf.summary.histogram('dqda_%d' % a, dqda[:, a]) |
|
|
|
actions_norm *= self._actions_regularizer |
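    # Deterministic policy gradient trick: with y = stop_gradient(dqda +
    # actions), the gradient of (y - actions)^2 w.r.t. the actor weights is
    # proportional to -dqda * d(actions)/d(weights), i.e. it reproduces
    # (dq/da)(da/dweights) without backpropagating through the critic again.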
|
return slim.losses.mean_squared_error(tf.stop_gradient(dqda + actions), |
|
actions, |
|
scope='actor_loss') + actions_norm |
|
|
|
@gin.configurable('ddpg_critic_function') |
|
def critic_function(self, critic_values, states, weights=None): |
|
"""Computes q values based on critic_net outputs, states, and weights. |
|
|
|
Args: |
|
critic_values: A tf.float32 [batch_size, ...] tensor representing outputs |
|
from the critic net. |
|
states: A [batch_size, num_state_dims] tensor representing a batch |
|
of states. |
|
weights: A list or Numpy array or tensor with a shape broadcastable to |
|
`critic_values`. |
|
Returns: |
|
A tf.float32 [batch_size] tensor representing q values. |
|
""" |
|
del states |
|
if weights is not None: |
|
weights = tf.convert_to_tensor(weights, dtype=critic_values.dtype) |
|
critic_values *= weights |
|
if critic_values.shape.ndims > 1: |
|
critic_values = tf.reduce_sum(critic_values, |
|
range(1, critic_values.shape.ndims)) |
|
critic_values.shape.assert_has_rank(1) |
|
return critic_values |
|
|
|
@gin.configurable('ddpg_update_targets') |
|
def update_targets(self, tau=1.0): |
|
"""Performs a soft update of the target network parameters. |
|
|
|
For each weight w_s in the actor/critic networks, and its corresponding |
|
weight w_t in the target actor/critic networks, a soft update is: |
|
      w_t = (1 - tau) * w_t + tau * w_s
|
|
|
Args: |
|
      tau: A float scalar in [0, 1].
|
Returns: |
|
An operation that performs a soft update of the target network parameters. |
|
Raises: |
|
ValueError: If `tau` is not in [0, 1]. |
|
""" |
|
if tau < 0 or tau > 1: |
|
raise ValueError('Input `tau` should be in [0, 1].') |
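    # Since this method is gin-configurable, tau is often bound in a .gin
    # file, e.g. (illustrative value): ddpg_update_targets.tau = 0.001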
|
update_actor = utils.soft_variables_update( |
|
slim.get_trainable_variables( |
|
utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)), |
|
slim.get_trainable_variables( |
|
utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)), |
|
tau) |
|
update_critic = utils.soft_variables_update( |
|
slim.get_trainable_variables( |
|
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)), |
|
slim.get_trainable_variables( |
|
utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)), |
|
tau) |
|
return tf.group(update_actor, update_critic, name='update_targets') |
|
|
|
def get_trainable_critic_vars(self): |
|
"""Returns a list of trainable variables in the critic network. |
|
|
|
Returns: |
|
A list of trainable variables in the critic network. |
|
""" |
|
return slim.get_trainable_variables( |
|
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)) |
|
|
|
def get_trainable_actor_vars(self): |
|
"""Returns a list of trainable variables in the actor network. |
|
|
|
Returns: |
|
A list of trainable variables in the actor network. |
|
""" |
|
return slim.get_trainable_variables( |
|
utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)) |
|
|
|
def get_critic_vars(self): |
|
"""Returns a list of all variables in the critic network. |
|
|
|
Returns: |
|
      A list of all model variables in the critic network.
|
""" |
|
return slim.get_model_variables( |
|
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)) |
|
|
|
def get_actor_vars(self): |
|
"""Returns a list of all variables in the actor network. |
|
|
|
Returns: |
|
      A list of all model variables in the actor network.
|
""" |
|
return slim.get_model_variables( |
|
utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)) |
|
|
|
def _validate_states(self, states): |
|
"""Raises a value error if `states` does not have the expected shape. |
|
|
|
Args: |
|
states: A tensor. |
|
Raises: |
|
ValueError: If states.shape or states.dtype are not compatible with |
|
observation_spec. |
|
""" |
|
states.shape.assert_is_compatible_with(self._state_shape) |
|
if not states.dtype.is_compatible_with(self._observation_spec.dtype): |
|
raise ValueError('states.dtype={} is not compatible with' |
|
' observation_spec.dtype={}'.format( |
|
states.dtype, self._observation_spec.dtype)) |
|
|
|
def _validate_actions(self, actions): |
|
"""Raises a value error if `actions` does not have the expected shape. |
|
|
|
Args: |
|
actions: A tensor. |
|
Raises: |
|
ValueError: If actions.shape or actions.dtype are not compatible with |
|
action_spec. |
|
""" |
|
actions.shape.assert_is_compatible_with(self._action_shape) |
|
if not actions.dtype.is_compatible_with(self._action_spec.dtype): |
|
raise ValueError('actions.dtype={} is not compatible with' |
|
' action_spec.dtype={}'.format( |
|
actions.dtype, self._action_spec.dtype)) |
|
|
|
|
|
@gin.configurable |
|
class TD3Agent(DdpgAgent): |
|
"""An RL agent that learns using the TD3 algorithm.""" |
|
|
|
ACTOR_NET_SCOPE = 'actor_net' |
|
CRITIC_NET_SCOPE = 'critic_net' |
|
CRITIC_NET2_SCOPE = 'critic_net2' |
|
TARGET_ACTOR_NET_SCOPE = 'target_actor_net' |
|
TARGET_CRITIC_NET_SCOPE = 'target_critic_net' |
|
TARGET_CRITIC_NET2_SCOPE = 'target_critic_net2' |
|
|
|
def __init__(self, |
|
observation_spec, |
|
action_spec, |
|
actor_net=networks.actor_net, |
|
critic_net=networks.critic_net, |
|
td_errors_loss=tf.losses.huber_loss, |
|
dqda_clipping=0., |
|
actions_regularizer=0., |
|
target_q_clipping=None, |
|
residual_phi=0.0, |
|
debug_summaries=False): |
|
"""Constructs a TD3 agent. |
|
|
|
Args: |
|
observation_spec: A TensorSpec defining the observations. |
|
action_spec: A BoundedTensorSpec defining the actions. |
|
      actor_net: A callable that creates the actor network. Must take the
        following arguments: states, action_spec. Please see
        networks.actor_net for an example.
|
critic_net: A callable that creates the critic network. Must take the |
|
following arguments: states, actions. Please see networks.critic_net |
|
for an example. |
|
td_errors_loss: A callable defining the loss function for the critic |
|
td error. |
|
dqda_clipping: (float) clips the gradient dqda element-wise between |
|
[-dqda_clipping, dqda_clipping]. Does not perform clipping if |
|
dqda_clipping == 0. |
|
actions_regularizer: A scalar, when positive penalizes the norm of the |
|
actions. This can prevent saturation of actions for the actor_loss. |
|
target_q_clipping: (tuple of floats) clips target q values within |
|
(low, high) values when computing the critic loss. |
|
residual_phi: (float) [0.0, 1.0] Residual algorithm parameter that |
|
interpolates between Q-learning and residual gradient algorithm. |
|
http://www.leemon.com/papers/1995b.pdf |
|
debug_summaries: If True, add summaries to help debug behavior. |
|
Raises: |
|
ValueError: If 'dqda_clipping' is < 0. |
|
""" |
|
self._observation_spec = observation_spec[0] |
|
self._action_spec = action_spec[0] |
|
self._state_shape = tf.TensorShape([None]).concatenate( |
|
self._observation_spec.shape) |
|
self._action_shape = tf.TensorShape([None]).concatenate( |
|
self._action_spec.shape) |
|
self._num_action_dims = self._action_spec.shape.num_elements() |
|
|
|
self._scope = tf.get_variable_scope().name |
|
self._actor_net = tf.make_template( |
|
self.ACTOR_NET_SCOPE, actor_net, create_scope_now_=True) |
|
self._critic_net = tf.make_template( |
|
self.CRITIC_NET_SCOPE, critic_net, create_scope_now_=True) |
|
self._critic_net2 = tf.make_template( |
|
self.CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True) |
|
self._target_actor_net = tf.make_template( |
|
self.TARGET_ACTOR_NET_SCOPE, actor_net, create_scope_now_=True) |
|
self._target_critic_net = tf.make_template( |
|
self.TARGET_CRITIC_NET_SCOPE, critic_net, create_scope_now_=True) |
|
self._target_critic_net2 = tf.make_template( |
|
self.TARGET_CRITIC_NET2_SCOPE, critic_net, create_scope_now_=True) |
|
self._td_errors_loss = td_errors_loss |
|
if dqda_clipping < 0: |
|
raise ValueError('dqda_clipping must be >= 0.') |
|
self._dqda_clipping = dqda_clipping |
|
self._actions_regularizer = actions_regularizer |
|
self._target_q_clipping = target_q_clipping |
|
self._residual_phi = residual_phi |
|
self._debug_summaries = debug_summaries |
|
|
|
def get_trainable_critic_vars(self): |
|
"""Returns a list of trainable variables in the critic network. |
|
NOTE: This gets the vars of both critic networks. |
|
|
|
Returns: |
|
A list of trainable variables in the critic network. |
|
""" |
|
return ( |
|
slim.get_trainable_variables( |
|
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE))) |
|
|
|
def critic_net(self, states, actions, for_critic_loss=False): |
|
"""Returns the output of the critic network. |
|
|
|
Args: |
|
states: A [batch_size, num_state_dims] tensor representing a batch |
|
of states. |
|
actions: A [batch_size, num_action_dims] tensor representing a batch |
|
of actions. |
|
Returns: |
|
q values: A [batch_size] tensor of q values. |
|
Raises: |
|
      ValueError: If `states` or `actions` do not have the expected dimensions.
|
""" |
|
values1 = self._critic_net(states, actions, |
|
for_critic_loss=for_critic_loss) |
|
values2 = self._critic_net2(states, actions, |
|
for_critic_loss=for_critic_loss) |
|
if for_critic_loss: |
|
return values1, values2 |
|
return values1 |
|
|
|
def target_critic_net(self, states, actions, for_critic_loss=False): |
|
"""Returns the output of the target critic network. |
|
|
|
The target network is used to compute stable targets for training. |
|
|
|
Args: |
|
states: A [batch_size, num_state_dims] tensor representing a batch |
|
of states. |
|
actions: A [batch_size, num_action_dims] tensor representing a batch |
|
of actions. |
|
Returns: |
|
q values: A [batch_size] tensor of q values. |
|
Raises: |
|
      ValueError: If `states` or `actions` do not have the expected dimensions.
|
""" |
|
self._validate_states(states) |
|
self._validate_actions(actions) |
|
values1 = tf.stop_gradient( |
|
self._target_critic_net(states, actions, |
|
for_critic_loss=for_critic_loss)) |
|
values2 = tf.stop_gradient( |
|
self._target_critic_net2(states, actions, |
|
for_critic_loss=for_critic_loss)) |
|
if for_critic_loss: |
|
return values1, values2 |
|
return values1 |
|
|
|
def value_net(self, states, for_critic_loss=False): |
|
"""Returns the output of the critic evaluated with the actor. |
|
|
|
Args: |
|
states: A [batch_size, num_state_dims] tensor representing a batch |
|
of states. |
|
Returns: |
|
q values: A [batch_size] tensor of q values. |
|
""" |
|
actions = self.actor_net(states) |
|
return self.critic_net(states, actions, |
|
for_critic_loss=for_critic_loss) |
|
|
|
def target_value_net(self, states, for_critic_loss=False): |
|
"""Returns the output of the target critic evaluated with the target actor. |
|
|
|
Args: |
|
states: A [batch_size, num_state_dims] tensor representing a batch |
|
of states. |
|
Returns: |
|
q values: A [batch_size] tensor of q values. |
|
""" |
|
target_actions = self.target_actor_net(states) |
|
noise = tf.clip_by_value( |
|
tf.random_normal(tf.shape(target_actions), stddev=0.2), -0.5, 0.5) |
|
values1, values2 = self.target_critic_net( |
|
states, target_actions + noise, |
|
for_critic_loss=for_critic_loss) |
|
values = tf.minimum(values1, values2) |
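    # Per TD3's clipped double-Q learning, both critics are trained against
    # the same min-of-two target, hence the duplicated return value (one
    # target per critic loss).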
|
return values, values |
|
|
|
@gin.configurable('td3_update_targets') |
|
def update_targets(self, tau=1.0): |
|
"""Performs a soft update of the target network parameters. |
|
|
|
For each weight w_s in the actor/critic networks, and its corresponding |
|
weight w_t in the target actor/critic networks, a soft update is: |
|
      w_t = (1 - tau) * w_t + tau * w_s
|
|
|
Args: |
|
      tau: A float scalar in [0, 1].
|
Returns: |
|
An operation that performs a soft update of the target network parameters. |
|
Raises: |
|
ValueError: If `tau` is not in [0, 1]. |
|
""" |
|
if tau < 0 or tau > 1: |
|
raise ValueError('Input `tau` should be in [0, 1].') |
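    # Note: scope-prefix matching means the 'critic_net' and
    # 'target_critic_net' prefixes below also pick up 'critic_net2' and
    # 'target_critic_net2', so both critic pairs are softly updated.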
|
update_actor = utils.soft_variables_update( |
|
slim.get_trainable_variables( |
|
utils.join_scope(self._scope, self.ACTOR_NET_SCOPE)), |
|
slim.get_trainable_variables( |
|
utils.join_scope(self._scope, self.TARGET_ACTOR_NET_SCOPE)), |
|
tau) |
|
|
|
update_critic = utils.soft_variables_update( |
|
slim.get_trainable_variables( |
|
utils.join_scope(self._scope, self.CRITIC_NET_SCOPE)), |
|
slim.get_trainable_variables( |
|
utils.join_scope(self._scope, self.TARGET_CRITIC_NET_SCOPE)), |
|
tau) |
|
return tf.group(update_actor, update_critic, name='update_targets') |
|
|
|
|
|
def gen_debug_td_error_summaries( |
|
target_q_values, q_values, td_targets, td_errors): |
|
"""Generates debug summaries for critic given a set of batch samples. |
|
|
|
Args: |
|
    target_q_values: predicted Q-values for the next states.
    q_values: the critic's Q-value predictions for the current states.
    td_targets: rewards plus the discounted target_q_values.
    td_errors: the difference between td_targets and q_values.
|
""" |
|
with tf.name_scope('td_errors'): |
|
tf.summary.histogram('td_targets', td_targets) |
|
tf.summary.histogram('q_values', q_values) |
|
tf.summary.histogram('target_q_values', target_q_values) |
|
tf.summary.histogram('td_errors', td_errors) |
|
with tf.name_scope('td_targets'): |
|
tf.summary.scalar('mean', tf.reduce_mean(td_targets)) |
|
tf.summary.scalar('max', tf.reduce_max(td_targets)) |
|
tf.summary.scalar('min', tf.reduce_min(td_targets)) |
|
with tf.name_scope('q_values'): |
|
tf.summary.scalar('mean', tf.reduce_mean(q_values)) |
|
tf.summary.scalar('max', tf.reduce_max(q_values)) |
|
tf.summary.scalar('min', tf.reduce_min(q_values)) |
|
with tf.name_scope('target_q_values'): |
|
tf.summary.scalar('mean', tf.reduce_mean(target_q_values)) |
|
tf.summary.scalar('max', tf.reduce_max(target_q_values)) |
|
tf.summary.scalar('min', tf.reduce_min(target_q_values)) |
|
with tf.name_scope('td_errors'): |
|
tf.summary.scalar('mean', tf.reduce_mean(td_errors)) |
|
tf.summary.scalar('max', tf.reduce_max(td_errors)) |
|
tf.summary.scalar('min', tf.reduce_min(td_errors)) |
|
tf.summary.scalar('mean_abs', tf.reduce_mean(tf.abs(td_errors))) |
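

# A minimal end-to-end sketch of how the pieces above fit together, assuming
# a replay-sampled `batch` with the usual fields (hypothetical names; optimizer
# settings are illustrative):
#
#   agent = DdpgAgent(observation_spec, action_spec)
#   critic_loss = agent.critic_loss(batch.states, batch.actions,
#                                   batch.rewards, batch.discounts,
#                                   batch.next_states)
#   actor_loss = agent.actor_loss(batch.states)
#   critic_train_op = slim.learning.create_train_op(
#       critic_loss, tf.train.AdamOptimizer(1e-3),
#       variables_to_train=agent.get_trainable_critic_vars())
#   actor_train_op = slim.learning.create_train_op(
#       actor_loss, tf.train.AdamOptimizer(1e-4),
#       variables_to_train=agent.get_trainable_actor_vars())
#   soft_update_op = agent.update_targets(tau=0.001)  # run after each step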
|
|