"""Trainer for coordinating single or multi-replica training. |
|
|
|
Main entry point for running models. Defines most of

the flags used by the different algorithms.
|
""" |
|
|
|
import tensorflow as tf |
|
import numpy as np |
|
import random |
|
import os |
|
import pickle |
|
|
|
from six.moves import xrange |
|
import controller |
|
import model |
|
import policy |
|
import baseline |
|
import objective |
|
import full_episode_objective |
|
import trust_region |
|
import optimizers |
|
import replay_buffer |
|
import expert_paths |
|
import gym_wrapper |
|
import env_spec |
|
|
|
app = tf.app |
|
flags = tf.flags |
|
logging = tf.logging |
|
gfile = tf.gfile |
|
|
|
FLAGS = flags.FLAGS |
|
|
|
flags.DEFINE_string('env', 'Copy-v0', 'environment name') |
|
flags.DEFINE_integer('batch_size', 100, 'batch size') |
|
flags.DEFINE_integer('replay_batch_size', None, 'replay batch size; defaults to batch_size') |
|
flags.DEFINE_integer('num_samples', 1, |
|
'number of samples from each random seed initialization') |
|
flags.DEFINE_integer('max_step', 200, 'max number of steps to train on') |
|
flags.DEFINE_integer('cutoff_agent', 0, |
|
                     'number of steps at which to cut off the agent. '

                     'Defaults to always cutting off.')
|
flags.DEFINE_integer('num_steps', 100000, 'number of training steps') |
|
flags.DEFINE_integer('validation_frequency', 100, |
|
                     'log training and eval statistics every this many steps')
|
|
|
flags.DEFINE_float('target_network_lag', 0.95, |
|
                   'exponential moving average coefficient applied to the online '

                   'network to obtain the target network')
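# In effect, the target network weights track an exponential moving average of
# the online network weights: target <- lag * target + (1 - lag) * online.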
|
flags.DEFINE_string('sample_from', 'online', |
|
'Sample actions from "online" network or "target" network') |
|
|
|
flags.DEFINE_string('objective', 'pcl', |
|
'pcl/upcl/a3c/trpo/reinforce/urex') |
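# pcl: Path Consistency Learning, upcl: Unified PCL, a3c: advantage actor-critic,
# trpo: Trust Region Policy Optimization, reinforce: REINFORCE,
# urex: Under-appreciated Reward Exploration.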
|
flags.DEFINE_bool('trust_region_p', False, |
|
'use trust region for policy optimization') |
|
flags.DEFINE_string('value_opt', None, |
|
'leave as None to optimize it along with policy ' |
|
'(using critic_weight). Otherwise set to ' |
|
'"best_fit" (least squares regression), "lbfgs", or "grad"') |
|
flags.DEFINE_float('max_divergence', 0.01, |
|
'max divergence (i.e. KL) to allow during ' |
|
'trust region optimization') |
|
|
|
flags.DEFINE_float('learning_rate', 0.01, 'learning rate') |
|
flags.DEFINE_float('clip_norm', 5.0, 'clip norm') |
|
flags.DEFINE_float('clip_adv', 0.0, 'Clip advantages at this value. ' |
|
'Leave as 0 to not clip at all.') |
|
flags.DEFINE_float('critic_weight', 0.1, 'critic weight') |
|
flags.DEFINE_float('tau', 0.1, 'entropy regularizer. '
|
'If using decaying tau, this is the final value.') |
|
flags.DEFINE_float('tau_decay', None, |
|
                   'multiplicative decay rate applied to tau every 100 steps')
|
flags.DEFINE_float('tau_start', 0.1, |
|
'start tau at this value') |
|
flags.DEFINE_float('eps_lambda', 0.0, 'relative entropy regularizer.') |
|
flags.DEFINE_bool('update_eps_lambda', False, |
|
'Update lambda automatically based on last 100 episodes.') |
|
flags.DEFINE_float('gamma', 1.0, 'discount') |
|
flags.DEFINE_integer('rollout', 10, 'rollout') |
|
flags.DEFINE_bool('use_target_values', False, |
|
'use target network for value estimates') |
|
flags.DEFINE_bool('fixed_std', True, |
|
'fix the std in Gaussian distributions') |
|
flags.DEFINE_bool('input_prev_actions', True, |
|
'input previous actions to policy network') |
|
flags.DEFINE_bool('recurrent', True, |
|
'use recurrent connections') |
|
flags.DEFINE_bool('input_time_step', False, |
|
                  'input time step into value calculations')
|
|
|
flags.DEFINE_bool('use_online_batch', True, 'train on batches as they are sampled') |
|
flags.DEFINE_bool('batch_by_steps', False, |
|
'ensure each training batch has batch_size * max_step steps') |
|
flags.DEFINE_bool('unify_episodes', False, |
|
'Make sure replay buffer holds entire episodes, ' |
|
'even across distinct sampling steps') |
|
flags.DEFINE_integer('replay_buffer_size', 5000, 'replay buffer size') |
|
flags.DEFINE_float('replay_buffer_alpha', 0.5, 'replay buffer alpha param') |
|
flags.DEFINE_integer('replay_buffer_freq', 0, |
|
'replay buffer frequency (only supports -1/0/1)') |
|
flags.DEFINE_string('eviction', 'rand', |
|
'how to evict from replay buffer: rand/rank/fifo') |
|
flags.DEFINE_string('prioritize_by', 'rewards', |
|
'Prioritize replay buffer by "rewards" or "step"') |
|
flags.DEFINE_integer('num_expert_paths', 0, |
|
'number of expert paths to seed replay buffer with') |
|
|
|
flags.DEFINE_integer('internal_dim', 256, 'RNN internal dim') |
|
flags.DEFINE_integer('value_hidden_layers', 0, |
|
'number of hidden layers in value estimate') |
|
flags.DEFINE_integer('tf_seed', 42, 'random seed for tensorflow') |
|
|
|
flags.DEFINE_string('save_trajectories_dir', None, |
|
'directory to save trajectories to, if desired') |
|
flags.DEFINE_string('load_trajectories_file', None, |
|
'file to load expert trajectories from') |
|
|
|
|
|
flags.DEFINE_bool('supervisor', False, 'use supervisor training') |
|
flags.DEFINE_integer('task_id', 0, 'task id') |
|
flags.DEFINE_integer('ps_tasks', 0, 'number of ps tasks') |
|
flags.DEFINE_integer('num_replicas', 1, 'number of replicas used') |
|
flags.DEFINE_string('master', 'local', 'name of master') |
|
flags.DEFINE_string('save_dir', '', 'directory to save model to') |
|
flags.DEFINE_string('load_path', '', 'path of a saved model to load if no checkpoint is found in save_dir')
|
|
|
|
|
class Trainer(object): |
|
"""Coordinates single or multi-replica training.""" |
|
|
|
def __init__(self): |
|
self.batch_size = FLAGS.batch_size |
|
self.replay_batch_size = FLAGS.replay_batch_size |
|
if self.replay_batch_size is None: |
|
self.replay_batch_size = self.batch_size |
|
self.num_samples = FLAGS.num_samples |
|
|
|
self.env_str = FLAGS.env |
|
self.env = gym_wrapper.GymWrapper(self.env_str, |
|
distinct=FLAGS.batch_size // self.num_samples, |
|
count=self.num_samples) |
|
self.eval_env = gym_wrapper.GymWrapper( |
|
self.env_str, |
|
distinct=FLAGS.batch_size // self.num_samples, |
|
count=self.num_samples) |
|
self.env_spec = env_spec.EnvSpec(self.env.get_one()) |
|
|
|
self.max_step = FLAGS.max_step |
|
self.cutoff_agent = FLAGS.cutoff_agent |
|
self.num_steps = FLAGS.num_steps |
|
self.validation_frequency = FLAGS.validation_frequency |
|
|
|
self.target_network_lag = FLAGS.target_network_lag |
|
self.sample_from = FLAGS.sample_from |
|
assert self.sample_from in ['online', 'target'] |
|
|
|
self.critic_weight = FLAGS.critic_weight |
|
self.objective = FLAGS.objective |
|
self.trust_region_p = FLAGS.trust_region_p |
|
self.value_opt = FLAGS.value_opt |
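    # Consistency checks: trust region optimization is only supported for PCL and
    # TRPO (and TRPO requires it); a separate value optimizer means the critic is
    # not also trained through the joint loss (critic_weight must be 0).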
|
assert not self.trust_region_p or self.objective in ['pcl', 'trpo'] |
|
assert self.objective != 'trpo' or self.trust_region_p |
|
assert self.value_opt is None or self.value_opt == 'None' or \ |
|
self.critic_weight == 0.0 |
|
self.max_divergence = FLAGS.max_divergence |
|
|
|
self.learning_rate = FLAGS.learning_rate |
|
self.clip_norm = FLAGS.clip_norm |
|
self.clip_adv = FLAGS.clip_adv |
|
self.tau = FLAGS.tau |
|
self.tau_decay = FLAGS.tau_decay |
|
self.tau_start = FLAGS.tau_start |
|
self.eps_lambda = FLAGS.eps_lambda |
|
self.update_eps_lambda = FLAGS.update_eps_lambda |
|
self.gamma = FLAGS.gamma |
|
self.rollout = FLAGS.rollout |
|
self.use_target_values = FLAGS.use_target_values |
|
self.fixed_std = FLAGS.fixed_std |
|
self.input_prev_actions = FLAGS.input_prev_actions |
|
self.recurrent = FLAGS.recurrent |
|
assert not self.trust_region_p or not self.recurrent |
|
self.input_time_step = FLAGS.input_time_step |
|
assert not self.input_time_step or (self.cutoff_agent <= self.max_step) |
|
|
|
self.use_online_batch = FLAGS.use_online_batch |
|
self.batch_by_steps = FLAGS.batch_by_steps |
|
self.unify_episodes = FLAGS.unify_episodes |
|
if self.unify_episodes: |
|
assert self.batch_size == 1 |
|
|
|
self.replay_buffer_size = FLAGS.replay_buffer_size |
|
self.replay_buffer_alpha = FLAGS.replay_buffer_alpha |
|
self.replay_buffer_freq = FLAGS.replay_buffer_freq |
|
assert self.replay_buffer_freq in [-1, 0, 1] |
|
self.eviction = FLAGS.eviction |
|
self.prioritize_by = FLAGS.prioritize_by |
|
assert self.prioritize_by in ['rewards', 'step'] |
|
self.num_expert_paths = FLAGS.num_expert_paths |
|
|
|
self.internal_dim = FLAGS.internal_dim |
|
self.value_hidden_layers = FLAGS.value_hidden_layers |
|
self.tf_seed = FLAGS.tf_seed |
|
|
|
self.save_trajectories_dir = FLAGS.save_trajectories_dir |
|
self.save_trajectories_file = ( |
|
os.path.join( |
|
self.save_trajectories_dir, self.env_str.replace('-', '_')) |
|
if self.save_trajectories_dir else None) |
|
self.load_trajectories_file = FLAGS.load_trajectories_file |
|
|
|
self.hparams = dict((attr, getattr(self, attr)) |
|
for attr in dir(self) |
|
if not attr.startswith('__') and |
|
not callable(getattr(self, attr))) |
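    # self.hparams now holds every public, non-callable attribute set above.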
|
|
|
def hparams_string(self): |
|
return '\n'.join('%s: %s' % item for item in sorted(self.hparams.items())) |
|
|
|
def get_objective(self): |
|
tau = self.tau |
|
if self.tau_decay is not None: |
|
assert self.tau_start >= self.tau |
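      # Anneal tau: tau(t) = max(tau_start * tau_decay**(global_step / 100), tau),
      # i.e. exponential decay every 100 steps, floored at the final tau value.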
|
tau = tf.maximum( |
|
tf.train.exponential_decay( |
|
self.tau_start, self.global_step, 100, self.tau_decay), |
|
self.tau) |
|
|
|
if self.objective in ['pcl', 'a3c', 'trpo', 'upcl']: |
|
cls = (objective.PCL if self.objective in ['pcl', 'upcl'] else |
|
objective.TRPO if self.objective == 'trpo' else |
|
objective.ActorCritic) |
|
policy_weight = 1.0 |
|
|
|
return cls(self.learning_rate, |
|
clip_norm=self.clip_norm, |
|
policy_weight=policy_weight, |
|
critic_weight=self.critic_weight, |
|
tau=tau, gamma=self.gamma, rollout=self.rollout, |
|
eps_lambda=self.eps_lambda, clip_adv=self.clip_adv, |
|
use_target_values=self.use_target_values) |
|
elif self.objective in ['reinforce', 'urex']: |
|
cls = (full_episode_objective.Reinforce |
|
if self.objective == 'reinforce' else |
|
full_episode_objective.UREX) |
|
return cls(self.learning_rate, |
|
clip_norm=self.clip_norm, |
|
num_samples=self.num_samples, |
|
tau=tau, bonus_weight=1.0) |
|
else: |
|
      raise ValueError('Unknown objective: %s' % self.objective)
|
|
|
def get_policy(self): |
|
if self.recurrent: |
|
cls = policy.Policy |
|
else: |
|
cls = policy.MLPPolicy |
|
return cls(self.env_spec, self.internal_dim, |
|
fixed_std=self.fixed_std, |
|
recurrent=self.recurrent, |
|
input_prev_actions=self.input_prev_actions) |
|
|
|
def get_baseline(self): |
|
cls = (baseline.UnifiedBaseline if self.objective == 'upcl' else |
|
baseline.Baseline) |
|
return cls(self.env_spec, self.internal_dim, |
|
input_prev_actions=self.input_prev_actions, |
|
input_time_step=self.input_time_step, |
|
input_policy_state=self.recurrent, |
|
n_hidden_layers=self.value_hidden_layers, |
|
hidden_dim=self.internal_dim, |
|
tau=self.tau) |
|
|
|
def get_trust_region_p_opt(self): |
|
if self.trust_region_p: |
|
return trust_region.TrustRegionOptimization( |
|
max_divergence=self.max_divergence) |
|
else: |
|
return None |
|
|
|
def get_value_opt(self): |
|
if self.value_opt == 'grad': |
|
return optimizers.GradOptimization( |
|
learning_rate=self.learning_rate, max_iter=5, mix_frac=0.05) |
|
elif self.value_opt == 'lbfgs': |
|
return optimizers.LbfgsOptimization(max_iter=25, mix_frac=0.1) |
|
elif self.value_opt == 'best_fit': |
|
return optimizers.BestFitOptimization(mix_frac=1.0) |
|
else: |
|
return None |
|
|
|
def get_model(self): |
|
cls = model.Model |
|
return cls(self.env_spec, self.global_step, |
|
target_network_lag=self.target_network_lag, |
|
sample_from=self.sample_from, |
|
get_policy=self.get_policy, |
|
get_baseline=self.get_baseline, |
|
get_objective=self.get_objective, |
|
get_trust_region_p_opt=self.get_trust_region_p_opt, |
|
get_value_opt=self.get_value_opt) |
|
|
|
def get_replay_buffer(self): |
|
if self.replay_buffer_freq <= 0: |
|
return None |
|
else: |
|
assert self.objective in ['pcl', 'upcl'], 'Can\'t use replay buffer with %s' % ( |
|
self.objective) |
|
cls = replay_buffer.PrioritizedReplayBuffer |
|
return cls(self.replay_buffer_size, |
|
alpha=self.replay_buffer_alpha, |
|
eviction_strategy=self.eviction) |
|
|
|
def get_buffer_seeds(self): |
|
return expert_paths.sample_expert_paths( |
|
self.num_expert_paths, self.env_str, self.env_spec, |
|
load_trajectories_file=self.load_trajectories_file) |
|
|
|
def get_controller(self, env): |
|
"""Get controller.""" |
|
cls = controller.Controller |
|
return cls(env, self.env_spec, self.internal_dim, |
|
use_online_batch=self.use_online_batch, |
|
batch_by_steps=self.batch_by_steps, |
|
unify_episodes=self.unify_episodes, |
|
replay_batch_size=self.replay_batch_size, |
|
max_step=self.max_step, |
|
cutoff_agent=self.cutoff_agent, |
|
save_trajectories_file=self.save_trajectories_file, |
|
use_trust_region=self.trust_region_p, |
|
use_value_opt=self.value_opt not in [None, 'None'], |
|
update_eps_lambda=self.update_eps_lambda, |
|
prioritize_by=self.prioritize_by, |
|
get_model=self.get_model, |
|
get_replay_buffer=self.get_replay_buffer, |
|
get_buffer_seeds=self.get_buffer_seeds) |
|
|
|
def do_before_step(self, step): |
|
pass |
|
|
|
def run(self): |
|
"""Run training.""" |
|
is_chief = FLAGS.task_id == 0 or not FLAGS.supervisor |
|
sv = None |
|
|
|
def init_fn(sess, saver): |
|
ckpt = None |
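      # When running under a Supervisor, restoring from save_dir is handled by the
      # Supervisor itself, so this init_fn only restores from an explicit load_path.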
|
if FLAGS.save_dir and sv is None: |
|
load_dir = FLAGS.save_dir |
|
ckpt = tf.train.get_checkpoint_state(load_dir) |
|
if ckpt and ckpt.model_checkpoint_path: |
|
logging.info('restoring from %s', ckpt.model_checkpoint_path) |
|
saver.restore(sess, ckpt.model_checkpoint_path) |
|
elif FLAGS.load_path: |
|
logging.info('restoring from %s', FLAGS.load_path) |
|
saver.restore(sess, FLAGS.load_path) |
|
|
|
if FLAGS.supervisor: |
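      # Distributed training: place variables via a replica device setter so they
      # live on the parameter servers, and let a Supervisor manage session creation,
      # checkpointing, and summaries.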
|
      with tf.device(tf.train.replica_device_setter(FLAGS.ps_tasks, merge_devices=True)):
|
self.global_step = tf.contrib.framework.get_or_create_global_step() |
|
tf.set_random_seed(FLAGS.tf_seed) |
|
self.controller = self.get_controller(self.env) |
|
self.model = self.controller.model |
|
self.controller.setup() |
|
with tf.variable_scope(tf.get_variable_scope(), reuse=True): |
|
self.eval_controller = self.get_controller(self.eval_env) |
|
self.eval_controller.setup(train=False) |
|
|
|
saver = tf.train.Saver(max_to_keep=10) |
|
step = self.model.global_step |
|
        sv = tf.train.Supervisor(logdir=FLAGS.save_dir,
|
is_chief=is_chief, |
|
saver=saver, |
|
save_model_secs=600, |
|
summary_op=None, |
|
save_summaries_secs=60, |
|
global_step=step, |
|
init_fn=lambda sess: init_fn(sess, saver)) |
|
        sess = sv.prepare_or_wait_for_session(FLAGS.master)
|
else: |
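      # Single-process training: a plain tf.Session with manual variable
      # initialization and checkpoint restore via init_fn.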
|
tf.set_random_seed(FLAGS.tf_seed) |
|
self.global_step = tf.contrib.framework.get_or_create_global_step() |
|
self.controller = self.get_controller(self.env) |
|
self.model = self.controller.model |
|
self.controller.setup() |
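      # Reuse the variables just created so the eval controller shares the same
      # policy and value weights as the training controller.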
|
with tf.variable_scope(tf.get_variable_scope(), reuse=True): |
|
self.eval_controller = self.get_controller(self.eval_env) |
|
self.eval_controller.setup(train=False) |
|
|
|
saver = tf.train.Saver(max_to_keep=10) |
|
sess = tf.Session() |
|
      sess.run(tf.global_variables_initializer())
|
init_fn(sess, saver) |
|
|
|
self.sv = sv |
|
self.sess = sess |
|
|
|
logging.info('hparams:\n%s', self.hparams_string()) |
|
|
|
model_step = sess.run(self.model.global_step) |
|
if model_step >= self.num_steps: |
|
logging.info('training has reached final step') |
|
return |
|
|
|
losses = [] |
|
rewards = [] |
|
all_ep_rewards = [] |
|
for step in xrange(1 + self.num_steps): |
|
|
|
      if sv is not None and sv.should_stop():
|
logging.info('stopping supervisor') |
|
break |
|
|
|
self.do_before_step(step) |
|
|
|
(loss, summary, |
|
total_rewards, episode_rewards) = self.controller.train(sess) |
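      # Also roll out the greedy policy on the separate eval environment.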
|
_, greedy_episode_rewards = self.eval_controller.eval(sess) |
|
self.controller.greedy_episode_rewards = greedy_episode_rewards |
|
losses.append(loss) |
|
rewards.append(total_rewards) |
|
all_ep_rewards.extend(episode_rewards) |
|
|
|
if (random.random() < 0.1 and summary and episode_rewards and |
|
is_chief and sv and sv._summary_writer): |
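        # Summaries are throttled to roughly 10% of steps to limit event-file size.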
|
sv.summary_computed(sess, summary) |
|
|
|
model_step = sess.run(self.model.global_step) |
|
if is_chief and step % self.validation_frequency == 0: |
|
logging.info('at training step %d, model step %d: ' |
|
'avg loss %f, avg reward %f, ' |
|
'episode rewards: %f, greedy rewards: %f', |
|
step, model_step, |
|
np.mean(losses), np.mean(rewards), |
|
np.mean(all_ep_rewards), |
|
np.mean(greedy_episode_rewards)) |
|
|
|
losses = [] |
|
rewards = [] |
|
all_ep_rewards = [] |
|
|
|
if model_step >= self.num_steps: |
|
logging.info('training has reached final step') |
|
break |
|
|
|
if is_chief and sv is not None: |
|
logging.info('saving final model to %s', sv.save_path) |
|
sv.saver.save(sess, sv.save_path, global_step=sv.global_step) |
|
|
|
|
|
def main(unused_argv): |
|
logging.set_verbosity(logging.INFO) |
|
trainer = Trainer() |
|
trainer.run() |
|
|
|
|
|
if __name__ == '__main__': |
|
app.run() |
|
|