r"""""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
from collections import namedtuple |
|
import os |
|
import time |
|
|
|
import tensorflow as tf |
|
|
|
import gin.tf |
|
|
|
flags = tf.app.flags |
|
|
|
|
|
flags.DEFINE_multi_string('config_file', None,
                          'List of paths to the config files.')
flags.DEFINE_multi_string('params', None,
                          'Newline separated list of Gin parameter bindings.')

flags.DEFINE_string('train_dir', None,
                    'Directory for writing logs/summaries during training.')
flags.DEFINE_string('master', 'local',
                    'BNS name of the TensorFlow master to use.')
flags.DEFINE_integer('task', 0, 'task id')
flags.DEFINE_integer('save_interval_secs', 300, 'The frequency at which '
                     'checkpoints are saved, in seconds.')
flags.DEFINE_integer('save_summaries_secs', 30, 'The frequency at which '
                     'summaries are saved, in seconds.')
flags.DEFINE_boolean('summarize_gradients', False,
                     'Whether to generate gradient summaries.')

FLAGS = flags.FLAGS
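
# Sketch (assumption, not part of the original module): how a training binary
# would typically consume the flags above; `train` stands in for a hypothetical
# gin-configurable entry point defined elsewhere.
#
#   def main(_):
#     gin.parse_config_files_and_bindings(FLAGS.config_file, FLAGS.params)
#     train(train_dir=FLAGS.train_dir)
#
#   if __name__ == '__main__':
#     tf.app.run()
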
TrainOps = namedtuple('TrainOps',
                      ['train_op', 'meta_train_op', 'collect_experience_op'])


class TrainStep(object):
  """Handles a training step."""

  def __init__(self,
               max_number_of_steps=0,
               num_updates_per_observation=1,
               num_collect_per_update=1,
               num_collect_per_meta_update=1,
               log_every_n_steps=1,
               policy_save_fn=None,
               save_policy_every_n_steps=0,
               should_stop_early=None):
    """Returns a function that is executed at each step of slim training.

    Args:
      max_number_of_steps: Optional maximum number of train steps to take.
      num_updates_per_observation: Number of updates per observation.
      num_collect_per_update: Number of experience-collection steps per
        low-level update.
      num_collect_per_meta_update: Number of collection steps per meta update.
      log_every_n_steps: The frequency, in terms of global steps, at which the
        loss and global step are logged.
      policy_save_fn: A callable taking a single tf.Session argument (e.g. a
        wrapped tf.train.Saver().save) used to save the policy.
      save_policy_every_n_steps: How frequently, in global steps, to save the
        policy.
      should_stop_early: Optional hook to report whether training should stop.

    Raises:
      ValueError: If policy_save_fn is not provided when
        save_policy_every_n_steps > 0.
    """
    if save_policy_every_n_steps and policy_save_fn is None:
      raise ValueError(
          'policy_save_fn is required when save_policy_every_n_steps > 0')
    self.max_number_of_steps = max_number_of_steps
    self.num_updates_per_observation = num_updates_per_observation
    self.num_collect_per_update = num_collect_per_update
    self.num_collect_per_meta_update = num_collect_per_meta_update
    self.log_every_n_steps = log_every_n_steps
    self.policy_save_fn = policy_save_fn
    self.save_policy_every_n_steps = save_policy_every_n_steps
    self.should_stop_early = should_stop_early
    self.last_global_step_val = 0
    self.train_op_fn = None
    self.collect_and_train_fn = None
    tf.logging.info('Training for %d max_number_of_steps',
                    self.max_number_of_steps)

  def train_step(self, sess, train_ops, global_step, _):
    """This function will be called at each step of training.

    This represents one step of the DDPG algorithm and can include:
    1. collect a <state, action, reward, next_state> transition
    2. update the target network
    3. train the actor
    4. train the critic

    Args:
      sess: A TensorFlow session.
      train_ops: A TrainOps tuple of ops to run.
      global_step: The global step tensor.
      _: Unused (the train_step_kwargs passed by the slim training loop).

    Returns:
      total_loss: A scalar total loss.
      should_stop: A boolean indicating whether training should stop.
    """
    start_time = time.time()
    if self.train_op_fn is None:
      # Build the session callables lazily on the first step, once the session
      # and ops are available.
      self.train_op_fn = sess.make_callable([train_ops.train_op, global_step])
      self.meta_train_op_fn = sess.make_callable(
          [train_ops.meta_train_op, global_step])
      self.collect_fn = sess.make_callable(
          [train_ops.collect_experience_op, global_step])
      self.collect_and_train_fn = sess.make_callable(
          [train_ops.train_op, global_step, train_ops.collect_experience_op])
      self.collect_and_meta_train_fn = sess.make_callable(
          [train_ops.meta_train_op, global_step,
           train_ops.collect_experience_op])
    # Run any extra collect-only and train-only steps, then one combined
    # collect-and-train step that also reads the global step.
    for _ in range(self.num_collect_per_update - 1):
      self.collect_fn()
    for _ in range(self.num_updates_per_observation - 1):
      self.train_op_fn()

    total_loss, global_step_val, _ = self.collect_and_train_fn()
    # Run a meta update whenever the global step crosses a multiple of
    # num_collect_per_meta_update.
    if (global_step_val // self.num_collect_per_meta_update !=
        self.last_global_step_val // self.num_collect_per_meta_update):
      self.meta_train_op_fn()

    time_elapsed = time.time() - start_time
    should_stop = False
    if self.max_number_of_steps:
      should_stop = global_step_val >= self.max_number_of_steps
    if global_step_val != self.last_global_step_val:
      # Save the policy whenever the global step crosses a multiple of
      # save_policy_every_n_steps.
      if (self.save_policy_every_n_steps and
          global_step_val // self.save_policy_every_n_steps !=
          self.last_global_step_val // self.save_policy_every_n_steps):
        self.policy_save_fn(sess)

      if (self.log_every_n_steps and
          global_step_val % self.log_every_n_steps == 0):
        tf.logging.info(
            'global step %d: loss = %.4f (%.3f sec/step) (%d steps/sec)',
            global_step_val, total_loss, time_elapsed, 1 / time_elapsed)

    self.last_global_step_val = global_step_val
    stop_early = bool(self.should_stop_early and self.should_stop_early())
    return total_loss, should_stop or stop_early
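

# Illustrative sketch (not part of the original module): one way to drive
# TrainStep.train_step directly with a raw tf.Session. The docstrings above
# indicate the intended use is as the `train_step_fn` of slim's training loop
# (e.g. tf.contrib.slim.learning.train); this helper only demonstrates the
# call contract. `train_ops` is assumed to be a TrainOps tuple of already-built
# ops, and `policy_save_fn` a callable taking a single session argument.
def example_train_loop(train_ops, policy_save_fn=None, max_steps=1000):
  """Runs TrainStep.train_step in a loop until it reports it should stop."""
  global_step = tf.train.get_or_create_global_step()
  step_fn = TrainStep(
      max_number_of_steps=max_steps,
      policy_save_fn=policy_save_fn,
      save_policy_every_n_steps=100 if policy_save_fn else 0).train_step
  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    should_stop = False
    total_loss = None
    while not should_stop:
      # The last argument mirrors slim's train_step_kwargs and is unused here.
      total_loss, should_stop = step_fn(sess, train_ops, global_step, None)
  return total_loss
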
def create_counter_summaries(counters):
  """Adds scalar summaries for `counters`, a list of (name, counter) tuples."""
  if counters:
    with tf.name_scope('Counters/'):
      for name, counter in counters:
        tf.summary.scalar(name, counter)
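
# Sketch (assumption, not part of the original module): typical usage with
# counter variables built elsewhere; the names below are illustrative only.
#
#   create_counter_summaries([('episodes', num_episodes_var),
#                             ('environment_steps', num_env_steps_var)])
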
def gen_debug_batch_summaries(batch):
  """Generates summaries for the sampled replay batch."""
  states, actions, rewards, _, next_states = batch
  with tf.name_scope('batch'):
    for s in range(states.get_shape()[-1]):
      tf.summary.histogram('states_%d' % s, states[:, s])
    for s in range(states.get_shape()[-1]):
      tf.summary.histogram('next_states_%d' % s, next_states[:, s])
    for a in range(actions.get_shape()[-1]):
      tf.summary.histogram('actions_%d' % a, actions[:, a])
    tf.summary.histogram('rewards', rewards)
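
# Sketch (assumption, not part of the original module): attaching the debug
# summaries to a sampled replay batch. `replay_buffer.get_random_batch` is a
# hypothetical API; whatever the source, `batch` is expected to unpack as
# (states, actions, rewards, <unused>, next_states).
#
#   batch = replay_buffer.get_random_batch(batch_size)
#   gen_debug_batch_summaries(batch)
#   summary_op = tf.summary.merge_all()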