"""Baseline model for value estimates. |
|
|
|
Implements the value component of the neural network. |
|
In some cases this is just an additional linear layer on the policy. |
|
In other cases, it is a completely separate neural network. |
|
""" |
|
|
|
from six.moves import xrange
import tensorflow as tf
import numpy as np


class Baseline(object):
  def __init__(self, env_spec, internal_policy_dim,
               input_prev_actions=True,
               input_time_step=False,
               input_policy_state=True,
               n_hidden_layers=0,
               hidden_dim=64,
               tau=0.0):
    self.env_spec = env_spec
    self.internal_policy_dim = internal_policy_dim
    self.input_prev_actions = input_prev_actions
    self.input_time_step = input_time_step
    self.input_policy_state = input_policy_state
    self.n_hidden_layers = n_hidden_layers
    self.hidden_dim = hidden_dim
    self.tau = tau

    # Initializer for the final linear value projection.
    self.matrix_init = tf.truncated_normal_initializer(stddev=0.01)

  def get_inputs(self, time_step, obs, prev_actions,
                 internal_policy_states):
    """Get inputs to network as single tensor."""
    # Start from a constant bias feature of 1 for every transition.
    inputs = [tf.ones_like(time_step)]
    input_dim = 1

    # Raw observations are only fed in when the policy's internal state is
    # not used as input.
    if not self.input_policy_state:
      for i, (obs_dim, obs_type) in enumerate(self.env_spec.obs_dims_and_types):
        if self.env_spec.is_discrete(obs_type):
          # Discrete observations are one-hot encoded.
          inputs.append(
              tf.one_hot(obs[i], obs_dim))
          input_dim += obs_dim
        elif self.env_spec.is_box(obs_type):
          # Box observations are fed in raw along with their squares.
          cur_obs = obs[i]
          inputs.append(cur_obs)
          inputs.append(cur_obs ** 2)
          input_dim += obs_dim * 2
        else:
          assert False

    if self.input_prev_actions:
      for i, (act_dim, act_type) in enumerate(self.env_spec.act_dims_and_types):
        if self.env_spec.is_discrete(act_type):
          inputs.append(
              tf.one_hot(prev_actions[i], act_dim))
          input_dim += act_dim
        elif self.env_spec.is_box(act_type):
          inputs.append(prev_actions[i])
          input_dim += act_dim
        else:
          assert False

    if self.input_policy_state:
      inputs.append(internal_policy_states)
      input_dim += self.internal_policy_dim

    if self.input_time_step:
      # Low-order polynomial features of the (scaled) time step.
      scaled_time = 0.01 * time_step
      inputs.extend([scaled_time, scaled_time ** 2, scaled_time ** 3])
      input_dim += 3

    return input_dim, tf.concat(inputs, 1)

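
  # Illustrative feature layout (hypothetical env_spec, not from this
  # codebase): with one discrete observation of dim 4, one discrete action of
  # dim 2, input_policy_state=False and input_prev_actions=True, the
  # concatenated input per transition is
  #   [1, one_hot(obs, 4), one_hot(prev_action, 2)]
  # so input_dim = 1 + 4 + 2 = 7.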
  def reshape_batched_inputs(self, all_obs, all_actions,
                             internal_policy_states, policy_logits):
    """Reshape inputs from [time_length, batch_size, ...] to
    [time_length * batch_size, ...].

    This allows for computing the value estimate in one go.
    """
    batch_size = tf.shape(all_obs[0])[1]
    time_length = tf.shape(all_obs[0])[0]

    reshaped_obs = []
    for obs, (obs_dim, obs_type) in zip(all_obs, self.env_spec.obs_dims_and_types):
      if self.env_spec.is_discrete(obs_type):
        reshaped_obs.append(tf.reshape(obs, [time_length * batch_size]))
      elif self.env_spec.is_box(obs_type):
        reshaped_obs.append(tf.reshape(obs, [time_length * batch_size, obs_dim]))

    reshaped_prev_act = []
    reshaped_policy_logits = []
    for i, (act_dim, act_type) in enumerate(self.env_spec.act_dims_and_types):
      prev_act = all_actions[i]
      if self.env_spec.is_discrete(act_type):
        reshaped_prev_act.append(
            tf.reshape(prev_act, [time_length * batch_size]))
      elif self.env_spec.is_box(act_type):
        reshaped_prev_act.append(
            tf.reshape(prev_act, [time_length * batch_size, act_dim]))

      reshaped_policy_logits.append(
          tf.reshape(policy_logits[i], [time_length * batch_size, -1]))

    reshaped_internal_policy_states = tf.reshape(
        internal_policy_states,
        [time_length * batch_size, self.internal_policy_dim])

    # Per-transition time step index; zeroed out when input_time_step is off.
    time_step = (float(self.input_time_step) *
                 tf.expand_dims(
                     tf.to_float(tf.range(time_length * batch_size) /
                                 batch_size), -1))

    return (time_step, reshaped_obs, reshaped_prev_act,
            reshaped_internal_policy_states,
            reshaped_policy_logits)

  def get_values(self, all_obs, all_actions, internal_policy_states,
                 policy_logits):
    """Get value estimates given input."""
    batch_size = tf.shape(all_obs[0])[1]
    time_length = tf.shape(all_obs[0])[0]

    (time_step, reshaped_obs, reshaped_prev_act,
     reshaped_internal_policy_states,
     reshaped_policy_logits) = self.reshape_batched_inputs(
         all_obs, all_actions, internal_policy_states, policy_logits)

    input_dim, inputs = self.get_inputs(
        time_step, reshaped_obs, reshaped_prev_act,
        reshaped_internal_policy_states)

    # Optional stack of tanh hidden layers on top of the concatenated features.
    for depth in xrange(self.n_hidden_layers):
      with tf.variable_scope('value_layer%d' % depth):
        w = tf.get_variable('w', [input_dim, self.hidden_dim])
        inputs = tf.nn.tanh(tf.matmul(inputs, w))
        input_dim = self.hidden_dim

    # Final linear projection to one scalar value per transition.
    w_v = tf.get_variable('w_v', [input_dim, 1],
                          initializer=self.matrix_init)
    values = tf.matmul(inputs, w_v)
    values = tf.reshape(values, [time_length, batch_size])

    # Drop the features of the last time step before returning them.
    inputs = inputs[:-batch_size]
    return values, inputs, w_v

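
# Minimal usage sketch (an assumption about the surrounding trainer, not code
# from this file): a caller that owns an `env_spec` and a recurrent policy
# with state size `rnn_state_dim` would do something like
#   baseline = Baseline(env_spec, internal_policy_dim=rnn_state_dim,
#                       n_hidden_layers=2)
#   values, final_inputs, w_v = baseline.get_values(
#       all_obs, all_actions, internal_policy_states, policy_logits)
# where every input tensor is laid out as [time_length, batch_size, ...].
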
class UnifiedBaseline(Baseline):
  """Baseline for Unified PCL."""

  def get_values(self, all_obs, all_actions, internal_policy_states,
                 policy_logits):
    """Get value estimates directly from the policy logits."""
    batch_size = tf.shape(all_obs[0])[1]
    time_length = tf.shape(all_obs[0])[0]

    (time_step, reshaped_obs, reshaped_prev_act,
     reshaped_internal_policy_states,
     reshaped_policy_logits) = self.reshape_batched_inputs(
         all_obs, all_actions, internal_policy_states, policy_logits)

    def f_transform(q_values, tau):
      # Numerically stable soft-max value:
      #   tau * log(sum_a exp(q_values[a] / tau)),
      # computed by factoring out the max entry.
      max_q = tf.reduce_max(q_values, -1, keep_dims=True)
      return tf.squeeze(max_q, [-1]) + tau * tf.log(
          tf.reduce_sum(tf.exp((q_values - max_q) / tau), -1))

    assert len(reshaped_policy_logits) == 1
    # Note: self.eps_lambda is not defined in Baseline.__init__; it is
    # expected to be assigned on this object externally before get_values
    # is called.
    values = f_transform(
        (self.tau + self.eps_lambda) * reshaped_policy_logits[0],
        (self.tau + self.eps_lambda))
    values = tf.reshape(values, [time_length, batch_size])

    input_dim, inputs = self.get_inputs(
        time_step, reshaped_obs, reshaped_prev_act,
        reshaped_internal_policy_states)

    w_v = tf.get_variable('w_v', [input_dim, 1],
                          initializer=self.matrix_init)

    # Drop the features of the last time step before returning them.
    inputs = inputs[:-batch_size]

    return values, inputs, w_v
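

if __name__ == '__main__':
  # Hedged, standalone sanity check (not part of the original training code):
  # verify that the numerically stable soft-max value used in UnifiedBaseline,
  #   max_q + tau * log(sum(exp((q - max_q) / tau))),
  # equals the direct form tau * log(sum(exp(q / tau))) on random inputs.
  q = np.random.randn(5, 3)
  tau = 0.1
  max_q = np.max(q, axis=-1, keepdims=True)
  stable = np.squeeze(max_q, axis=-1) + tau * np.log(
      np.sum(np.exp((q - max_q) / tau), axis=-1))
  direct = tau * np.log(np.sum(np.exp(q / tau), axis=-1))
  print('max abs difference: %g' % np.max(np.abs(stable - direct)))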