r"""Script for training an RL agent using the UVF algorithm. |
|
|
|
To run locally: See run_train.py |
|
""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import os |
|
import time |
|
import tensorflow as tf |
|
slim = tf.contrib.slim |
|
|
|
import gin.tf |
|
|
|
import train_utils |
|
import agent as agent_ |
|
from agents import circular_buffer |
|
from utils import utils as uvf_utils |
|
from environments import create_maze_env |
|
|
|
|
|
|
|
flags = tf.app.flags |
|
|
|
FLAGS = flags.FLAGS |
|
flags.DEFINE_string('goal_sample_strategy', 'sample',
                    'Goal relabeling strategy used when training the '
                    'meta-agent: None, sample, or FuN.')
|
|
|
LOAD_PATH = None |
|
|
|
|
|
def collect_experience(tf_env, agent, meta_agent, state_preprocess, |
|
replay_buffer, meta_replay_buffer, |
|
action_fn, meta_action_fn, |
|
environment_steps, num_episodes, num_resets, |
|
episode_rewards, episode_meta_rewards, |
|
store_context, |
|
disable_agent_reset): |
|
"""Collect experience in a tf_env into a replay_buffer using action_fn. |
|
|
|
Args: |
|
tf_env: A TFEnvironment. |
|
agent: A UVF agent. |
|
meta_agent: A Meta Agent. |
|
replay_buffer: A Replay buffer to collect experience in. |
|
meta_replay_buffer: A Replay buffer to collect meta agent experience in. |
|
action_fn: A function to produce actions given current state. |
|
meta_action_fn: A function to produce meta actions given current state. |
|
environment_steps: A variable to count the number of steps in the tf_env. |
|
num_episodes: A variable to count the number of episodes. |
|
num_resets: A variable to count the number of resets. |
|
store_context: A boolean to check if store context in replay. |
|
disable_agent_reset: A boolean that disables agent from resetting. |
|
|
|
Returns: |
|
A collect_experience_op that excute an action and store into the |
|
replay_buffers |
|
""" |
|
tf_env.start_collect() |
|
state = tf_env.current_obs() |
|
state_repr = state_preprocess(state) |
|
action = action_fn(state, context=None) |
|
|
|
with tf.control_dependencies([state]): |
|
transition_type, reward, discount = tf_env.step(action) |
|
|
|
def increment_step(): |
|
return environment_steps.assign_add(1) |
|
|
|
def increment_episode(): |
|
return num_episodes.assign_add(1) |
|
|
|
def increment_reset(): |
|
return num_resets.assign_add(1) |
|
|
|
def update_episode_rewards(context_reward, meta_reward, reset): |
|
new_episode_rewards = tf.concat( |
|
[episode_rewards[:1] + context_reward, episode_rewards[1:]], 0) |
|
new_episode_meta_rewards = tf.concat( |
|
[episode_meta_rewards[:1] + meta_reward, |
|
episode_meta_rewards[1:]], 0) |
|
return tf.group( |
|
episode_rewards.assign( |
|
tf.cond(reset, |
|
lambda: tf.concat([[0.], episode_rewards[:-1]], 0), |
|
lambda: new_episode_rewards)), |
|
episode_meta_rewards.assign( |
|
tf.cond(reset, |
|
lambda: tf.concat([[0.], episode_meta_rewards[:-1]], 0), |
|
lambda: new_episode_meta_rewards))) |
|
|
|
def no_op_int(): |
|
return tf.constant(0, dtype=tf.int64) |
|
|
|
step_cond = agent.step_cond_fn(state, action, |
|
transition_type, |
|
environment_steps, num_episodes) |
|
reset_episode_cond = agent.reset_episode_cond_fn( |
|
state, action, |
|
transition_type, environment_steps, num_episodes) |
|
reset_env_cond = agent.reset_env_cond_fn(state, action, |
|
transition_type, |
|
environment_steps, num_episodes) |
|
|
|
increment_step_op = tf.cond(step_cond, increment_step, no_op_int) |
|
increment_episode_op = tf.cond(reset_episode_cond, increment_episode, |
|
no_op_int) |
|
increment_reset_op = tf.cond(reset_env_cond, increment_reset, no_op_int) |
|
increment_op = tf.group(increment_step_op, increment_episode_op, |
|
increment_reset_op) |
|
|
|
with tf.control_dependencies([increment_op, reward, discount]): |
|
next_state = tf_env.current_obs() |
|
next_state_repr = state_preprocess(next_state) |
|
next_reset_episode_cond = tf.logical_or( |
|
agent.reset_episode_cond_fn( |
|
state, action, |
|
transition_type, environment_steps, num_episodes), |
|
tf.equal(discount, 0.0)) |
|
|
|
if store_context: |
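    # Snapshot the current context values; tf.identity(var) + tf.zeros_like(var)
    # forces a copy of each context variable's value at this point in the
    # graph rather than a reference that would reflect later updates.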
|
    context = [tf.identity(var) + tf.zeros_like(var)
               for var in agent.context_vars]
    meta_context = [tf.identity(var) + tf.zeros_like(var)
                    for var in meta_agent.context_vars]
|
else: |
|
context = [] |
|
meta_context = [] |
|
with tf.control_dependencies([next_state] + context + meta_context): |
|
if disable_agent_reset: |
|
collect_experience_ops = [tf.no_op()] |
|
else: |
|
collect_experience_ops = agent.cond_begin_episode_op( |
|
tf.logical_not(reset_episode_cond), |
|
[state, action, reward, next_state, |
|
state_repr, next_state_repr], |
|
mode='explore', meta_action_fn=meta_action_fn) |
|
context_reward, meta_reward = collect_experience_ops |
|
collect_experience_ops = list(collect_experience_ops) |
|
collect_experience_ops.append( |
|
update_episode_rewards(tf.reduce_sum(context_reward), meta_reward, |
|
reset_episode_cond)) |
|
|
|
meta_action_every_n = agent.tf_context.meta_action_every_n |
|
with tf.control_dependencies(collect_experience_ops): |
|
transition = [state, action, reward, discount, next_state] |
|
|
|
meta_action = tf.to_float( |
|
tf.concat(context, -1)) |
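    # The meta-action is the low-level agent's current context (its goal).
    # meta_end below is True on the first step of a new meta-period
    # (t % meta_action_every_n == 1, t > 1), i.e. when the previous
    # meta-transition is complete.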
|
|
|
meta_end = tf.logical_and( |
|
tf.equal(agent.tf_context.t % meta_action_every_n, 1), |
|
agent.tf_context.t > 1) |
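    # The variables below accumulate, across the low-level steps of the
    # current meta-period, everything needed to emit a complete
    # meta-transition: the per-step states and actions, the meta-period's
    # start state, its summed reward, the meta-action, and the meta-agent's
    # context.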
|
with tf.variable_scope(tf.get_variable_scope(), reuse=tf.AUTO_REUSE): |
|
states_var = tf.get_variable('states_var', |
|
[meta_action_every_n, state.shape[-1]], |
|
state.dtype) |
|
actions_var = tf.get_variable('actions_var', |
|
[meta_action_every_n, action.shape[-1]], |
|
action.dtype) |
|
state_var = tf.get_variable('state_var', state.shape, state.dtype) |
|
reward_var = tf.get_variable('reward_var', reward.shape, reward.dtype) |
|
meta_action_var = tf.get_variable('meta_action_var', |
|
meta_action.shape, meta_action.dtype) |
|
meta_context_var = [ |
|
tf.get_variable('meta_context_var%d' % idx, |
|
meta_context[idx].shape, meta_context[idx].dtype) |
|
for idx in range(len(meta_context))] |
|
|
|
actions_var_upd = tf.scatter_update( |
|
actions_var, (agent.tf_context.t - 2) % meta_action_every_n, action) |
|
with tf.control_dependencies([actions_var_upd]): |
|
actions = tf.identity(actions_var) + tf.zeros_like(actions_var) |
|
meta_reward = tf.identity(meta_reward) + tf.zeros_like(meta_reward) |
|
meta_reward = tf.reshape(meta_reward, reward.shape) |
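    # The reward written into the meta-transition is the meta-reward scaled by
    # a fixed factor of 0.1, summed over the meta-period via reward_var.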
|
|
|
reward = 0.1 * meta_reward |
|
meta_transition = [state_var, meta_action_var, |
|
reward_var + reward, |
|
discount * (1 - tf.to_float(next_reset_episode_cond)), |
|
next_state] |
|
meta_transition.extend([states_var, actions]) |
|
if store_context: |
|
transition += context + list(agent.context_vars) |
|
meta_transition += meta_context_var + list(meta_agent.context_vars) |
|
|
|
    meta_step_cond = tf.squeeze(
        tf.logical_and(step_cond,
                       tf.logical_or(next_reset_episode_cond, meta_end)))
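    # A meta-transition is added only on steps where the low-level transition
    # is also added and the meta-period (or the episode) has just ended.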
|
|
|
collect_experience_op = tf.group( |
|
replay_buffer.maybe_add(transition, step_cond), |
|
meta_replay_buffer.maybe_add(meta_transition, meta_step_cond), |
|
) |
|
|
|
with tf.control_dependencies([collect_experience_op]): |
|
collect_experience_op = tf.cond(reset_env_cond, |
|
tf_env.reset, |
|
tf_env.current_time_step) |
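  # At the start of each meta-period, reset the accumulators (start state,
  # summed reward, meta-action, meta context); otherwise keep accumulating.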
|
|
|
meta_period = tf.equal(agent.tf_context.t % meta_action_every_n, 1) |
|
states_var_upd = tf.scatter_update( |
|
states_var, (agent.tf_context.t - 1) % meta_action_every_n, |
|
next_state) |
|
state_var_upd = tf.assign( |
|
state_var, |
|
tf.cond(meta_period, lambda: next_state, lambda: state_var)) |
|
reward_var_upd = tf.assign( |
|
reward_var, |
|
tf.cond(meta_period, |
|
lambda: tf.zeros_like(reward_var), |
|
lambda: reward_var + reward)) |
|
meta_action = tf.to_float(tf.concat(agent.context_vars, -1)) |
|
meta_action_var_upd = tf.assign( |
|
meta_action_var, |
|
tf.cond(meta_period, lambda: meta_action, lambda: meta_action_var)) |
|
meta_context_var_upd = [ |
|
tf.assign( |
|
meta_context_var[idx], |
|
tf.cond(meta_period, |
|
lambda: meta_agent.context_vars[idx], |
|
lambda: meta_context_var[idx])) |
|
for idx in range(len(meta_context))] |
|
|
|
return tf.group( |
|
collect_experience_op, |
|
states_var_upd, |
|
state_var_upd, |
|
reward_var_upd, |
|
meta_action_var_upd, |
|
*meta_context_var_upd) |
|
|
|
|
|
def sample_best_meta_actions(state_reprs, next_state_reprs, prev_meta_actions, |
|
low_states, low_actions, low_state_reprs, |
|
inverse_dynamics, uvf_agent, k=10): |
|
"""Return meta-actions which approximately maximize low-level log-probs.""" |
|
sampled_actions = inverse_dynamics.sample(state_reprs, next_state_reprs, k, prev_meta_actions) |
|
sampled_actions = tf.stop_gradient(sampled_actions) |
|
sampled_log_probs = tf.reshape(uvf_agent.log_probs( |
|
tf.tile(low_states, [k, 1, 1]), |
|
tf.tile(low_actions, [k, 1, 1]), |
|
tf.tile(low_state_reprs, [k, 1, 1]), |
|
[tf.reshape(sampled_actions, [-1, sampled_actions.shape[-1]])]), |
|
[k, low_states.shape[0], |
|
low_states.shape[1], -1]) |
|
fitness = tf.reduce_sum(sampled_log_probs, [2, 3]) |
|
best_actions = tf.argmax(fitness, 0) |
|
actions = tf.gather_nd( |
|
sampled_actions, |
|
tf.stack([best_actions, |
|
tf.range(prev_meta_actions.shape[0], dtype=tf.int64)], -1)) |
|
return actions |
|
|
|
|
|
@gin.configurable |
|
def train_uvf(train_dir, |
|
environment=None, |
|
num_bin_actions=3, |
|
agent_class=None, |
|
meta_agent_class=None, |
|
state_preprocess_class=None, |
|
inverse_dynamics_class=None, |
|
exp_action_wrapper=None, |
|
replay_buffer=None, |
|
meta_replay_buffer=None, |
|
replay_num_steps=1, |
|
meta_replay_num_steps=1, |
|
critic_optimizer=None, |
|
actor_optimizer=None, |
|
meta_critic_optimizer=None, |
|
meta_actor_optimizer=None, |
|
repr_optimizer=None, |
|
relabel_contexts=False, |
|
meta_relabel_contexts=False, |
|
batch_size=64, |
|
repeat_size=0, |
|
num_episodes_train=2000, |
|
initial_episodes=2, |
|
initial_steps=None, |
|
num_updates_per_observation=1, |
|
num_collect_per_update=1, |
|
num_collect_per_meta_update=1, |
|
gamma=1.0, |
|
meta_gamma=1.0, |
|
reward_scale_factor=1.0, |
|
target_update_period=1, |
|
should_stop_early=None, |
|
clip_gradient_norm=0.0, |
|
summarize_gradients=False, |
|
debug_summaries=False, |
|
log_every_n_steps=100, |
|
prefetch_queue_capacity=2, |
|
policy_save_dir='policy', |
|
save_policy_every_n_steps=1000, |
|
save_policy_interval_secs=0, |
|
replay_context_ratio=0.0, |
|
next_state_as_context_ratio=0.0, |
|
state_index=0, |
|
zero_timer_ratio=0.0, |
|
timer_index=-1, |
|
debug=False, |
|
max_policies_to_save=None, |
|
max_steps_per_episode=None, |
|
load_path=LOAD_PATH): |
|
"""Train an agent.""" |
|
tf_env = create_maze_env.TFPyEnvironment(environment) |
|
observation_spec = [tf_env.observation_spec()] |
|
action_spec = [tf_env.action_spec()] |
|
|
|
max_steps_per_episode = max_steps_per_episode or tf_env.pyenv.max_episode_steps |
|
|
|
  assert max_steps_per_episode, 'max_steps_per_episode needs to be set'
|
|
|
if initial_steps is None: |
|
initial_steps = initial_episodes * max_steps_per_episode |
|
|
|
if agent_class.ACTION_TYPE == 'discrete': |
|
    assert False, 'Discrete action spaces are not supported.'
|
else: |
|
assert agent_class.ACTION_TYPE == 'continuous' |
|
|
|
assert agent_class.ACTION_TYPE == meta_agent_class.ACTION_TYPE |
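  # Build the high-level (meta) agent, the low-level UVF agent, the state
  # preprocessor, and the inverse dynamics model used for goal relabeling.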
|
with tf.variable_scope('meta_agent'): |
|
meta_agent = meta_agent_class( |
|
observation_spec, |
|
action_spec, |
|
tf_env, |
|
debug_summaries=debug_summaries) |
|
meta_agent.set_replay(replay=meta_replay_buffer) |
|
|
|
with tf.variable_scope('uvf_agent'): |
|
uvf_agent = agent_class( |
|
observation_spec, |
|
action_spec, |
|
tf_env, |
|
debug_summaries=debug_summaries) |
|
uvf_agent.set_meta_agent(agent=meta_agent) |
|
uvf_agent.set_replay(replay=replay_buffer) |
|
|
|
with tf.variable_scope('state_preprocess'): |
|
state_preprocess = state_preprocess_class() |
|
|
|
with tf.variable_scope('inverse_dynamics'): |
|
inverse_dynamics = inverse_dynamics_class( |
|
meta_agent.sub_context_as_action_specs[0]) |
|
|
|
|
|
global_step = tf.contrib.framework.get_or_create_global_step() |
|
num_episodes = tf.Variable(0, dtype=tf.int64, name='num_episodes') |
|
num_resets = tf.Variable(0, dtype=tf.int64, name='num_resets') |
|
num_updates = tf.Variable(0, dtype=tf.int64, name='num_updates') |
|
num_meta_updates = tf.Variable(0, dtype=tf.int64, name='num_meta_updates') |
|
episode_rewards = tf.Variable([0.] * 100, name='episode_rewards') |
|
episode_meta_rewards = tf.Variable([0.] * 100, name='episode_meta_rewards') |
|
|
|
|
|
train_utils.create_counter_summaries([ |
|
('environment_steps', global_step), |
|
('num_episodes', num_episodes), |
|
('num_resets', num_resets), |
|
('num_updates', num_updates), |
|
('num_meta_updates', num_meta_updates), |
|
('replay_buffer_adds', replay_buffer.get_num_adds()), |
|
('meta_replay_buffer_adds', meta_replay_buffer.get_num_adds()), |
|
]) |
|
|
|
tf.summary.scalar('avg_episode_rewards', |
|
tf.reduce_mean(episode_rewards[1:])) |
|
tf.summary.scalar('avg_episode_meta_rewards', |
|
tf.reduce_mean(episode_meta_rewards[1:])) |
|
tf.summary.histogram('episode_rewards', episode_rewards[1:]) |
|
tf.summary.histogram('episode_meta_rewards', episode_meta_rewards[1:]) |
|
|
|
|
|
action_fn = uvf_agent.action |
|
action_fn = uvf_agent.add_noise_fn(action_fn, global_step=None) |
|
meta_action_fn = meta_agent.action |
|
meta_action_fn = meta_agent.add_noise_fn(meta_action_fn, global_step=None) |
|
meta_actions_fn = meta_agent.actions |
|
meta_actions_fn = meta_agent.add_noise_fn(meta_actions_fn, global_step=None) |
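  # Two copies of the experience-collection graph are built: one is run to
  # prefill the replay buffers before training (see initialize_training_fn),
  # the other is interleaved with training updates via TrainOps.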
|
init_collect_experience_op = collect_experience( |
|
tf_env, |
|
uvf_agent, |
|
meta_agent, |
|
state_preprocess, |
|
replay_buffer, |
|
meta_replay_buffer, |
|
action_fn, |
|
meta_action_fn, |
|
environment_steps=global_step, |
|
num_episodes=num_episodes, |
|
num_resets=num_resets, |
|
episode_rewards=episode_rewards, |
|
episode_meta_rewards=episode_meta_rewards, |
|
store_context=True, |
|
disable_agent_reset=False, |
|
) |
|
|
|
|
|
collect_experience_op = collect_experience( |
|
tf_env, |
|
uvf_agent, |
|
meta_agent, |
|
state_preprocess, |
|
replay_buffer, |
|
meta_replay_buffer, |
|
action_fn, |
|
meta_action_fn, |
|
environment_steps=global_step, |
|
num_episodes=num_episodes, |
|
num_resets=num_resets, |
|
episode_rewards=episode_rewards, |
|
episode_meta_rewards=episode_meta_rewards, |
|
store_context=True, |
|
disable_agent_reset=False, |
|
) |
|
|
|
train_op_list = [] |
|
repr_train_op = tf.constant(0.0) |
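  # Build critic and actor training ops twice: first for the meta-agent
  # ('meta'), then for the low-level UVF agent ('nometa').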
|
for mode in ['meta', 'nometa']: |
|
if mode == 'meta': |
|
agent = meta_agent |
|
buff = meta_replay_buffer |
|
critic_opt = meta_critic_optimizer |
|
actor_opt = meta_actor_optimizer |
|
relabel = meta_relabel_contexts |
|
num_steps = meta_replay_num_steps |
|
      my_gamma = meta_gamma
|
n_updates = num_meta_updates |
|
else: |
|
agent = uvf_agent |
|
buff = replay_buffer |
|
critic_opt = critic_optimizer |
|
actor_opt = actor_optimizer |
|
relabel = relabel_contexts |
|
num_steps = replay_num_steps |
|
my_gamma = gamma |
|
n_updates = num_updates |
|
|
|
with tf.name_scope(mode): |
|
batch = buff.get_random_batch(batch_size, num_steps=num_steps) |
|
states, actions, rewards, discounts, next_states = batch[:5] |
|
with tf.name_scope('Reward'): |
|
tf.summary.scalar('average_step_reward', tf.reduce_mean(rewards)) |
|
rewards *= reward_scale_factor |
|
batch_queue = slim.prefetch_queue.prefetch_queue( |
|
[states, actions, rewards, discounts, next_states] + batch[5:], |
|
capacity=prefetch_queue_capacity, |
|
name='batch_queue') |
|
|
|
batch_dequeue = batch_queue.dequeue() |
|
if repeat_size > 0: |
|
batch_dequeue = [ |
|
tf.tile(batch, (repeat_size+1,) + (1,) * (batch.shape.ndims - 1)) |
|
for batch in batch_dequeue |
|
] |
|
batch_size *= (repeat_size + 1) |
|
states, actions, rewards, discounts, next_states = batch_dequeue[:5] |
|
if mode == 'meta': |
|
low_states = batch_dequeue[5] |
|
low_actions = batch_dequeue[6] |
|
low_state_reprs = state_preprocess(low_states) |
|
state_reprs = state_preprocess(states) |
|
next_state_reprs = state_preprocess(next_states) |
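      # For the meta-agent, optionally relabel the stored goals so they remain
      # consistent with the current low-level policy (off-policy correction),
      # as selected by FLAGS.goal_sample_strategy.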
|
|
|
if mode == 'meta': |
|
prev_actions = actions |
|
if FLAGS.goal_sample_strategy == 'None': |
|
pass |
|
elif FLAGS.goal_sample_strategy == 'FuN': |
|
actions = inverse_dynamics.sample(state_reprs, next_state_reprs, 1, prev_actions, sc=0.1) |
|
actions = tf.stop_gradient(actions) |
|
elif FLAGS.goal_sample_strategy == 'sample': |
|
actions = sample_best_meta_actions(state_reprs, next_state_reprs, prev_actions, |
|
low_states, low_actions, low_state_reprs, |
|
inverse_dynamics, uvf_agent, k=10) |
|
else: |
|
          assert False, ('Unknown goal_sample_strategy: %s' %
                         FLAGS.goal_sample_strategy)
|
|
|
if state_preprocess.trainable and mode == 'meta': |
|
|
|
|
|
repr_loss, _, _ = state_preprocess.loss(states, next_states, low_actions, low_states) |
|
repr_train_op = slim.learning.create_train_op( |
|
repr_loss, |
|
repr_optimizer, |
|
global_step=None, |
|
update_ops=None, |
|
summarize_gradients=summarize_gradients, |
|
clip_gradient_norm=clip_gradient_norm, |
|
variables_to_train=state_preprocess.get_trainable_vars(),) |
|
|
|
|
|
contexts, next_contexts = agent.sample_contexts( |
|
mode='train', batch_size=batch_size, |
|
state=states, next_state=next_states, |
|
) |
|
if not relabel: |
|
contexts, next_contexts = ( |
|
batch_dequeue[-2*len(contexts):-1*len(contexts)], |
|
batch_dequeue[-1*len(contexts):]) |
|
|
|
merged_states = agent.merged_states(states, contexts) |
|
merged_next_states = agent.merged_states(next_states, next_contexts) |
|
if mode == 'nometa': |
|
context_rewards, context_discounts = agent.compute_rewards( |
|
'train', state_reprs, actions, rewards, next_state_reprs, contexts) |
|
elif mode == 'meta': |
|
_, context_discounts = agent.compute_rewards( |
|
'train', states, actions, rewards, next_states, contexts) |
|
context_rewards = rewards |
|
|
|
if agent.gamma_index is not None: |
|
context_discounts *= tf.cast( |
|
tf.reshape(contexts[agent.gamma_index], (-1,)), |
|
dtype=context_discounts.dtype) |
|
      else:
        context_discounts *= my_gamma
|
|
|
critic_loss = agent.critic_loss(merged_states, actions, |
|
context_rewards, context_discounts, |
|
merged_next_states) |
|
|
|
critic_loss = tf.reduce_mean(critic_loss) |
|
|
|
actor_loss = agent.actor_loss(merged_states, actions, |
|
context_rewards, context_discounts, |
|
merged_next_states) |
|
actor_loss *= tf.to_float( |
|
tf.equal(n_updates % target_update_period, 0)) |
|
|
|
critic_train_op = slim.learning.create_train_op( |
|
critic_loss, |
|
critic_opt, |
|
global_step=n_updates, |
|
update_ops=None, |
|
summarize_gradients=summarize_gradients, |
|
clip_gradient_norm=clip_gradient_norm, |
|
variables_to_train=agent.get_trainable_critic_vars(),) |
|
critic_train_op = uvf_utils.tf_print( |
|
critic_train_op, [critic_train_op], |
|
message='critic_loss', |
|
print_freq=1000, |
|
name='critic_loss') |
|
train_op_list.append(critic_train_op) |
|
if actor_loss is not None: |
|
actor_train_op = slim.learning.create_train_op( |
|
actor_loss, |
|
actor_opt, |
|
global_step=None, |
|
update_ops=None, |
|
summarize_gradients=summarize_gradients, |
|
clip_gradient_norm=clip_gradient_norm, |
|
variables_to_train=agent.get_trainable_actor_vars(),) |
|
actor_train_op = uvf_utils.tf_print( |
|
actor_train_op, [actor_train_op], |
|
message='actor_loss', |
|
print_freq=1000, |
|
name='actor_loss') |
|
train_op_list.append(actor_train_op) |
|
|
|
assert len(train_op_list) == 4 |
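  # train_op_list is [meta critic, meta actor, low-level critic, low-level
  # actor]; the slices below depend on this ordering.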
|
|
|
with tf.control_dependencies(train_op_list[2:]): |
|
update_targets_op = uvf_utils.periodically( |
|
uvf_agent.update_targets, target_update_period, 'update_targets') |
|
if meta_agent is not None: |
|
with tf.control_dependencies(train_op_list[:2]): |
|
update_meta_targets_op = uvf_utils.periodically( |
|
meta_agent.update_targets, target_update_period, 'update_targets') |
|
|
|
assert_op = tf.Assert( |
|
tf.less_equal(global_step, 200 + num_episodes_train * max_steps_per_episode), |
|
[global_step]) |
|
with tf.control_dependencies([update_targets_op, assert_op]): |
|
train_op = tf.add_n(train_op_list[2:], name='post_update_targets') |
|
|
|
train_op += repr_train_op |
|
with tf.control_dependencies([update_meta_targets_op, assert_op]): |
|
meta_train_op = tf.add_n(train_op_list[:2], |
|
name='post_update_meta_targets') |
|
|
|
if debug_summaries: |
|
    train_utils.gen_debug_batch_summaries(batch)
|
slim.summaries.add_histogram_summaries( |
|
uvf_agent.get_trainable_critic_vars(), 'critic_vars') |
|
slim.summaries.add_histogram_summaries( |
|
uvf_agent.get_trainable_actor_vars(), 'actor_vars') |
|
|
|
train_ops = train_utils.TrainOps(train_op, meta_train_op, |
|
collect_experience_op) |
|
|
|
policy_save_path = os.path.join(train_dir, policy_save_dir, 'model.ckpt') |
|
policy_vars = uvf_agent.get_actor_vars() + meta_agent.get_actor_vars() + [ |
|
global_step, num_episodes, num_resets |
|
] + list(uvf_agent.context_vars) + list(meta_agent.context_vars) + state_preprocess.get_trainable_vars() |
|
|
|
policy_vars += uvf_agent.get_trainable_critic_vars() + meta_agent.get_trainable_critic_vars() |
|
policy_saver = tf.train.Saver( |
|
policy_vars, max_to_keep=max_policies_to_save, sharded=False) |
|
|
|
lowlevel_vars = (uvf_agent.get_actor_vars() + |
|
uvf_agent.get_trainable_critic_vars() + |
|
state_preprocess.get_trainable_vars()) |
|
lowlevel_saver = tf.train.Saver(lowlevel_vars) |
|
|
|
def policy_save_fn(sess): |
|
policy_saver.save( |
|
sess, policy_save_path, global_step=global_step, write_meta_graph=False) |
|
if save_policy_interval_secs > 0: |
|
tf.logging.info( |
|
'Wait %d secs after save policy.' % save_policy_interval_secs) |
|
time.sleep(save_policy_interval_secs) |
|
|
|
train_step_fn = train_utils.TrainStep( |
|
max_number_of_steps=num_episodes_train * max_steps_per_episode + 100, |
|
num_updates_per_observation=num_updates_per_observation, |
|
num_collect_per_update=num_collect_per_update, |
|
num_collect_per_meta_update=num_collect_per_meta_update, |
|
log_every_n_steps=log_every_n_steps, |
|
policy_save_fn=policy_save_fn, |
|
save_policy_every_n_steps=save_policy_every_n_steps, |
|
should_stop_early=should_stop_early).train_step |
|
|
|
local_init_op = tf.local_variables_initializer() |
|
init_targets_op = tf.group(uvf_agent.update_targets(1.0), |
|
meta_agent.update_targets(1.0)) |
|
|
|
def initialize_training_fn(sess): |
|
"""Initialize training function.""" |
|
sess.run(local_init_op) |
|
sess.run(init_targets_op) |
|
if load_path: |
|
tf.logging.info('Restoring low-level from %s' % load_path) |
|
lowlevel_saver.restore(sess, load_path) |
|
global_step_value = sess.run(global_step) |
|
assert global_step_value == 0, 'Global step should be zero.' |
|
collect_experience_call = sess.make_callable( |
|
init_collect_experience_op) |
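    # Prefill both replay buffers with initial_steps environment steps before
    # any training updates happen.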
|
|
|
for _ in range(initial_steps): |
|
collect_experience_call() |
|
|
|
train_saver = tf.train.Saver(max_to_keep=2, sharded=True) |
|
tf.logging.info('train dir: %s', train_dir) |
|
return slim.learning.train( |
|
train_ops, |
|
train_dir, |
|
train_step_fn=train_step_fn, |
|
save_interval_secs=FLAGS.save_interval_secs, |
|
saver=train_saver, |
|
log_every_n_steps=0, |
|
global_step=global_step, |
|
master="", |
|
is_chief=(FLAGS.task == 0), |
|
save_summaries_secs=FLAGS.save_summaries_secs, |
|
init_fn=initialize_training_fn) |
|
|