"""Policy neural network. |
|
|
|
Implements network which takes in input and produces actions |
|
and log probabilities given a sampling distribution parameterization. |
|
""" |
|
|
|
import tensorflow as tf |
|
import numpy as np |
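
# A minimal usage sketch (illustrative only; `env_spec`, `obs`,
# `prev_internal_state`, and `prev_actions` are assumed to be supplied by the
# surrounding training code):
#
#   policy = Policy(env_spec, internal_dim=128)
#   next_state, actions = policy.sample_step(
#       obs, prev_internal_state, prev_actions)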
|
|
|
|
|
class Policy(object): |
|
def __init__(self, env_spec, internal_dim, |
|
fixed_std=True, recurrent=True, |
|
input_prev_actions=True): |
|
self.env_spec = env_spec |
|
self.internal_dim = internal_dim |
|
self.rnn_state_dim = self.internal_dim |
|
self.fixed_std = fixed_std |
|
self.recurrent = recurrent |
|
self.input_prev_actions = input_prev_actions |
|
|
|
self.matrix_init = tf.truncated_normal_initializer(stddev=0.01) |
|
self.vector_init = tf.constant_initializer(0.0) |
|
|
|
@property |
|
def input_dim(self): |
|
return (self.env_spec.total_obs_dim + |
|
self.env_spec.total_sampled_act_dim * self.input_prev_actions) |
|
|
|
@property |
|
def output_dim(self): |
|
return self.env_spec.total_sampling_act_dim |
|
|
|
def get_cell(self): |
|
"""Get RNN cell.""" |
|
self.cell_input_dim = self.internal_dim // 2 |
|
cell = tf.contrib.rnn.LSTMCell(self.cell_input_dim, |
|
state_is_tuple=False, |
|
reuse=tf.get_variable_scope().reuse) |
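    # Project the LSTM output to the concatenated sampling-distribution
    # parameters (self.output_dim).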
|
|
|
cell = tf.contrib.rnn.OutputProjectionWrapper( |
|
cell, self.output_dim, |
|
reuse=tf.get_variable_scope().reuse) |
|
|
|
return cell |
|
|
|
def core(self, obs, prev_internal_state, prev_actions): |
|
"""Core neural network taking in inputs and outputting sampling |
|
distribution parameters.""" |
|
batch_size = tf.shape(obs[0])[0] |
|
if not self.recurrent: |
|
prev_internal_state = tf.zeros([batch_size, self.rnn_state_dim]) |
|
|
|
cell = self.get_cell() |
|
|
|
b = tf.get_variable('input_bias', [self.cell_input_dim], |
|
initializer=self.vector_init) |
|
cell_input = tf.nn.bias_add(tf.zeros([batch_size, self.cell_input_dim]), b) |
|
|
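    # Build the cell input by adding one linear embedding per observation
    # modality: discrete observations are one-hot encoded before the matmul,
    # continuous (box) observations are projected directly.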
|
for i, (obs_dim, obs_type) in enumerate(self.env_spec.obs_dims_and_types): |
|
w = tf.get_variable('w_state%d' % i, [obs_dim, self.cell_input_dim], |
|
initializer=self.matrix_init) |
|
if self.env_spec.is_discrete(obs_type): |
|
cell_input += tf.matmul(tf.one_hot(obs[i], obs_dim), w) |
|
elif self.env_spec.is_box(obs_type): |
|
cell_input += tf.matmul(obs[i], w) |
|
else: |
|
assert False |
|
|
|
if self.input_prev_actions: |
|
if self.env_spec.combine_actions: |
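        # Combined actions are decoded as a mixed-radix integer: each original
        # action is recovered with mod, then the remaining digits are shifted
        # down by integer division.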
|
prev_action = prev_actions[0] |
|
for i, action_dim in enumerate(self.env_spec.orig_act_dims): |
|
act = tf.mod(prev_action, action_dim) |
|
w = tf.get_variable('w_prev_action%d' % i, [action_dim, self.cell_input_dim], |
|
initializer=self.matrix_init) |
|
cell_input += tf.matmul(tf.one_hot(act, action_dim), w) |
|
          prev_action = tf.to_int32(prev_action // action_dim)
|
else: |
|
for i, (act_dim, act_type) in enumerate(self.env_spec.act_dims_and_types): |
|
w = tf.get_variable('w_prev_action%d' % i, [act_dim, self.cell_input_dim], |
|
initializer=self.matrix_init) |
|
if self.env_spec.is_discrete(act_type): |
|
cell_input += tf.matmul(tf.one_hot(prev_actions[i], act_dim), w) |
|
elif self.env_spec.is_box(act_type): |
|
cell_input += tf.matmul(prev_actions[i], w) |
|
else: |
|
assert False |
|
|
|
output, next_state = cell(cell_input, prev_internal_state) |
|
|
|
return output, next_state |
|
|
|
def sample_action(self, logits, sampling_dim, |
|
act_dim, act_type, greedy=False): |
|
"""Sample an action from a distribution.""" |
|
if self.env_spec.is_discrete(act_type): |
|
if greedy: |
|
act = tf.argmax(logits, 1) |
|
else: |
|
act = tf.reshape(tf.multinomial(logits, 1), [-1]) |
|
elif self.env_spec.is_box(act_type): |
|
      means = logits[:, :sampling_dim // 2]
|
      std = logits[:, sampling_dim // 2:]
|
if greedy: |
|
act = means |
|
else: |
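        # Draw a reparameterized sample: mean + std * standard-normal noise.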
|
batch_size = tf.shape(logits)[0] |
|
act = means + std * tf.random_normal([batch_size, act_dim]) |
|
else: |
|
assert False |
|
|
|
return act |
|
|
|
def entropy(self, logits, |
|
sampling_dim, act_dim, act_type): |
|
"""Calculate entropy of distribution.""" |
|
if self.env_spec.is_discrete(act_type): |
|
entropy = tf.reduce_sum( |
|
-tf.nn.softmax(logits) * tf.nn.log_softmax(logits), -1) |
|
elif self.env_spec.is_box(act_type): |
|
      means = logits[:, :sampling_dim // 2]
|
      std = logits[:, sampling_dim // 2:]
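      # Entropy of a diagonal Gaussian, summed over action dimensions:
      # 0.5 * (1 + log(2 * pi * std^2)) = 0.5 * log(2 * pi * e * std^2).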
|
entropy = tf.reduce_sum( |
|
0.5 * (1 + tf.log(2 * np.pi * tf.square(std))), -1) |
|
else: |
|
assert False |
|
|
|
return entropy |
|
|
|
def self_kl(self, logits, |
|
sampling_dim, act_dim, act_type): |
|
"""Calculate KL of distribution with itself. |
|
|
|
Used layer only for the gradients. |
|
""" |
|
|
|
if self.env_spec.is_discrete(act_type): |
|
probs = tf.nn.softmax(logits) |
|
log_probs = tf.nn.log_softmax(logits) |
|
self_kl = tf.reduce_sum( |
|
tf.stop_gradient(probs) * |
|
(tf.stop_gradient(log_probs) - log_probs), -1) |
|
elif self.env_spec.is_box(act_type): |
|
      means = logits[:, :sampling_dim // 2]
|
      std = logits[:, sampling_dim // 2:]
|
my_means = tf.stop_gradient(means) |
|
my_std = tf.stop_gradient(std) |
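      # Per-dimension KL(N(my_means, my_std) || N(means, std)):
      # log(std / my_std) + (my_std^2 + (my_means - means)^2) / (2 std^2) - 0.5.
      # With the stop-gradient copies this evaluates to zero, but it carries
      # the gradients of KL(old || new) with respect to the current parameters.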
|
self_kl = tf.reduce_sum( |
|
tf.log(std / my_std) + |
|
(tf.square(my_std) + tf.square(my_means - means)) / |
|
(2.0 * tf.square(std)) - 0.5, |
|
-1) |
|
else: |
|
assert False |
|
|
|
return self_kl |
|
|
|
def log_prob_action(self, action, logits, |
|
sampling_dim, act_dim, act_type): |
|
"""Calculate log-prob of action sampled from distribution.""" |
|
if self.env_spec.is_discrete(act_type): |
|
act_log_prob = tf.reduce_sum( |
|
tf.one_hot(action, act_dim) * tf.nn.log_softmax(logits), -1) |
|
elif self.env_spec.is_box(act_type): |
|
      means = logits[:, :sampling_dim // 2]
|
      std = logits[:, sampling_dim // 2:]
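      # Diagonal-Gaussian log-density, summed over dimensions below:
      # -0.5 * log(2 * pi * std^2) - 0.5 * (x - mean)^2 / std^2.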
|
act_log_prob = (- 0.5 * tf.log(2 * np.pi * tf.square(std)) |
|
- 0.5 * tf.square(action - means) / tf.square(std)) |
|
act_log_prob = tf.reduce_sum(act_log_prob, -1) |
|
else: |
|
assert False |
|
|
|
return act_log_prob |
|
|
|
def sample_actions(self, output, actions=None, greedy=False): |
|
"""Sample all actions given output of core network.""" |
|
sampled_actions = [] |
|
logits = [] |
|
log_probs = [] |
|
entropy = [] |
|
self_kl = [] |
|
|
|
start_idx = 0 |
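    # Walk through the flat network output, carving out one slice of
    # sampling-distribution parameters per action component.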
|
for i, (act_dim, act_type) in enumerate(self.env_spec.act_dims_and_types): |
|
sampling_dim = self.env_spec.sampling_dim(act_dim, act_type) |
|
if self.fixed_std and self.env_spec.is_box(act_type): |
|
act_logits = output[:, start_idx:start_idx + act_dim] |
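        # With a fixed (state-independent) std, the network outputs only the
        # means; a learned log-std variable is broadcast to the batch shape by
        # adding 0 * act_logits, and the 1e-6 keeps the std strictly positive.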
|
|
|
log_std = tf.get_variable('std%d' % i, [1, sampling_dim // 2]) |
|
|
|
act_logits = tf.concat( |
|
[act_logits, |
|
1e-6 + tf.exp(log_std) + 0 * act_logits], 1) |
|
else: |
|
act_logits = output[:, start_idx:start_idx + sampling_dim] |
|
|
|
if actions is None: |
|
act = self.sample_action(act_logits, sampling_dim, |
|
act_dim, act_type, |
|
greedy=greedy) |
|
else: |
|
act = actions[i] |
|
|
|
ent = self.entropy(act_logits, sampling_dim, act_dim, act_type) |
|
kl = self.self_kl(act_logits, sampling_dim, act_dim, act_type) |
|
|
|
act_log_prob = self.log_prob_action( |
|
act, act_logits, |
|
sampling_dim, act_dim, act_type) |
|
|
|
sampled_actions.append(act) |
|
logits.append(act_logits) |
|
log_probs.append(act_log_prob) |
|
entropy.append(ent) |
|
self_kl.append(kl) |
|
|
|
start_idx += sampling_dim |
|
|
|
assert start_idx == self.env_spec.total_sampling_act_dim |
|
|
|
return sampled_actions, logits, log_probs, entropy, self_kl |
|
|
|
def get_kl(self, my_logits, other_logits): |
|
"""Calculate KL between one policy output and another.""" |
|
kl = [] |
|
for i, (act_dim, act_type) in enumerate(self.env_spec.act_dims_and_types): |
|
sampling_dim = self.env_spec.sampling_dim(act_dim, act_type) |
|
single_my_logits = my_logits[i] |
|
single_other_logits = other_logits[i] |
|
if self.env_spec.is_discrete(act_type): |
|
my_probs = tf.nn.softmax(single_my_logits) |
|
my_log_probs = tf.nn.log_softmax(single_my_logits) |
|
other_log_probs = tf.nn.log_softmax(single_other_logits) |
|
my_kl = tf.reduce_sum(my_probs * (my_log_probs - other_log_probs), -1) |
|
elif self.env_spec.is_box(act_type): |
|
        my_means = single_my_logits[:, :sampling_dim // 2]
|
        my_std = single_my_logits[:, sampling_dim // 2:]
|
        other_means = single_other_logits[:, :sampling_dim // 2]
|
        other_std = single_other_logits[:, sampling_dim // 2:]
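        # KL between diagonal Gaussians, summed over dimensions:
        # log(s2 / s1) + (s1^2 + (m1 - m2)^2) / (2 * s2^2) - 0.5.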
|
my_kl = tf.reduce_sum( |
|
tf.log(other_std / my_std) + |
|
(tf.square(my_std) + tf.square(my_means - other_means)) / |
|
(2.0 * tf.square(other_std)) - 0.5, |
|
-1) |
|
else: |
|
assert False |
|
|
|
kl.append(my_kl) |
|
|
|
return kl |
|
|
|
def single_step(self, prev, cur, greedy=False): |
|
"""Single RNN step. Equivalently, single-time-step sampled actions.""" |
|
prev_internal_state, prev_actions, _, _, _, _ = prev |
|
obs, actions = cur |
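    # `prev` is the previous tf.scan output (only its state and actions are
    # needed); `cur` carries the current observations and, optionally, the
    # actions whose log-probabilities should be evaluated.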
|
|
|
|
|
output, next_state = self.core( |
|
obs, prev_internal_state, prev_actions) |
|
|
|
|
|
(actions, logits, log_probs, |
|
entropy, self_kl) = self.sample_actions( |
|
output, actions=actions, greedy=greedy) |
|
|
|
return (next_state, tuple(actions), tuple(logits), tuple(log_probs), |
|
tuple(entropy), tuple(self_kl)) |
|
|
|
def sample_step(self, obs, prev_internal_state, prev_actions, greedy=False): |
|
"""Sample single step from policy.""" |
|
(next_state, sampled_actions, logits, log_probs, |
|
entropies, self_kls) = self.single_step( |
|
(prev_internal_state, prev_actions, None, None, None, None), |
|
(obs, None), greedy=greedy) |
|
return next_state, sampled_actions |
|
|
|
def multi_step(self, all_obs, initial_state, all_actions): |
|
"""Calculate log-probs and other calculations on batch of episodes.""" |
|
batch_size = tf.shape(initial_state)[0] |
|
time_length = tf.shape(all_obs[0])[0] |
|
initial_actions = [act[0] for act in all_actions] |
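    # act[0] is treated as the action preceding the first observation, so each
    # action tensor is rotated by one step to align the action taken at obs[t]
    # with scan step t; the wrapped-around final entry is discarded by the
    # [:-1] slices below.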
|
all_actions = [tf.concat([act[1:], act[0:1]], 0) |
|
for act in all_actions] |
|
|
|
(internal_states, _, logits, log_probs, |
|
entropies, self_kls) = tf.scan( |
|
self.single_step, |
|
(all_obs, all_actions), |
|
initializer=self.get_initializer( |
|
batch_size, initial_state, initial_actions)) |
|
|
|
|
|
log_probs = [log_prob[:-1] for log_prob in log_probs] |
|
entropies = [entropy[:-1] for entropy in entropies] |
|
self_kls = [self_kl[:-1] for self_kl in self_kls] |
|
|
|
return internal_states, logits, log_probs, entropies, self_kls |
|
|
|
def get_initializer(self, batch_size, initial_state, initial_actions): |
|
"""Get initializer for RNN.""" |
|
logits_init = [] |
|
log_probs_init = [] |
|
for act_dim, act_type in self.env_spec.act_dims_and_types: |
|
sampling_dim = self.env_spec.sampling_dim(act_dim, act_type) |
|
logits_init.append(tf.zeros([batch_size, sampling_dim])) |
|
log_probs_init.append(tf.zeros([batch_size])) |
|
entropy_init = [tf.zeros([batch_size]) for _ in self.env_spec.act_dims] |
|
self_kl_init = [tf.zeros([batch_size]) for _ in self.env_spec.act_dims] |
|
|
|
return (initial_state, |
|
tuple(initial_actions), |
|
tuple(logits_init), tuple(log_probs_init), |
|
tuple(entropy_init), |
|
tuple(self_kl_init)) |
|
|
|
def calculate_kl(self, my_logits, other_logits): |
|
"""Calculate KL between one policy and another on batch of episodes.""" |
|
batch_size = tf.shape(my_logits[0])[1] |
|
time_length = tf.shape(my_logits[0])[0] |
|
|
|
reshaped_my_logits = [ |
|
tf.reshape(my_logit, [batch_size * time_length, -1]) |
|
for my_logit in my_logits] |
|
reshaped_other_logits = [ |
|
tf.reshape(other_logit, [batch_size * time_length, -1]) |
|
for other_logit in other_logits] |
|
|
|
kl = self.get_kl(reshaped_my_logits, reshaped_other_logits) |
|
kl = [tf.reshape(kkl, [time_length, batch_size]) |
|
for kkl in kl] |
|
return kl |
|
|
|
|
|
class MLPPolicy(Policy): |
|
"""Non-recurrent policy.""" |
|
|
|
def get_cell(self): |
|
self.cell_input_dim = self.internal_dim |
|
|
|
def mlp(cell_input, prev_internal_state): |
|
w1 = tf.get_variable('w1', [self.cell_input_dim, self.internal_dim]) |
|
b1 = tf.get_variable('b1', [self.internal_dim]) |
|
|
|
w2 = tf.get_variable('w2', [self.internal_dim, self.internal_dim]) |
|
b2 = tf.get_variable('b2', [self.internal_dim]) |
|
|
|
w3 = tf.get_variable('w3', [self.internal_dim, self.internal_dim]) |
|
b3 = tf.get_variable('b3', [self.internal_dim]) |
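      # NOTE: w3 / b3 are created here but are not used by the forward pass
      # below.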
|
|
|
proj = tf.get_variable( |
|
'proj', [self.internal_dim, self.output_dim]) |
|
|
|
hidden = cell_input |
|
hidden = tf.tanh(tf.nn.bias_add(tf.matmul(hidden, w1), b1)) |
|
hidden = tf.tanh(tf.nn.bias_add(tf.matmul(hidden, w2), b2)) |
|
|
|
output = tf.matmul(hidden, proj) |
|
|
|
return output, hidden |
|
|
|
return mlp |
|
|
|
def single_step(self, obs, actions, prev_actions, greedy=False): |
|
"""Single step.""" |
|
batch_size = tf.shape(obs[0])[0] |
|
prev_internal_state = tf.zeros([batch_size, self.internal_dim]) |
|
|
|
output, next_state = self.core( |
|
obs, prev_internal_state, prev_actions) |
|
|
|
|
|
(actions, logits, log_probs, |
|
entropy, self_kl) = self.sample_actions( |
|
output, actions=actions, greedy=greedy) |
|
|
|
return (next_state, tuple(actions), tuple(logits), tuple(log_probs), |
|
tuple(entropy), tuple(self_kl)) |
|
|
|
def sample_step(self, obs, prev_internal_state, prev_actions, greedy=False): |
|
"""Sample single step from policy.""" |
|
(next_state, sampled_actions, logits, log_probs, |
|
entropies, self_kls) = self.single_step(obs, None, prev_actions, |
|
greedy=greedy) |
|
return next_state, sampled_actions |
|
|
|
def multi_step(self, all_obs, initial_state, all_actions): |
|
"""Calculate log-probs and other calculations on batch of episodes.""" |
|
batch_size = tf.shape(initial_state)[0] |
|
time_length = tf.shape(all_obs[0])[0] |
|
|
|
|
|
reshaped_obs = [] |
|
for obs, (obs_dim, obs_type) in zip(all_obs, self.env_spec.obs_dims_and_types): |
|
if self.env_spec.is_discrete(obs_type): |
|
reshaped_obs.append(tf.reshape(obs, [time_length * batch_size])) |
|
elif self.env_spec.is_box(obs_type): |
|
reshaped_obs.append(tf.reshape(obs, [time_length * batch_size, obs_dim])) |
|
|
|
reshaped_act = [] |
|
reshaped_prev_act = [] |
|
for i, (act_dim, act_type) in enumerate(self.env_spec.act_dims_and_types): |
|
act = tf.concat([all_actions[i][1:], all_actions[i][0:1]], 0) |
|
prev_act = all_actions[i] |
|
if self.env_spec.is_discrete(act_type): |
|
reshaped_act.append(tf.reshape(act, [time_length * batch_size])) |
|
reshaped_prev_act.append( |
|
tf.reshape(prev_act, [time_length * batch_size])) |
|
elif self.env_spec.is_box(act_type): |
|
reshaped_act.append( |
|
tf.reshape(act, [time_length * batch_size, act_dim])) |
|
reshaped_prev_act.append( |
|
tf.reshape(prev_act, [time_length * batch_size, act_dim])) |
|
|
|
|
|
(internal_states, _, logits, log_probs, |
|
entropies, self_kls) = self.single_step( |
|
reshaped_obs, reshaped_act, reshaped_prev_act) |
|
|
|
|
|
internal_states = tf.reshape(internal_states, [time_length, batch_size, -1]) |
|
logits = [tf.reshape(logit, [time_length, batch_size, -1]) |
|
for logit in logits] |
|
log_probs = [tf.reshape(log_prob, [time_length, batch_size])[:-1] |
|
for log_prob in log_probs] |
|
entropies = [tf.reshape(ent, [time_length, batch_size])[:-1] |
|
for ent in entropies] |
|
self_kls = [tf.reshape(self_kl, [time_length, batch_size])[:-1] |
|
for self_kl in self_kls] |
|
|
|
return internal_states, logits, log_probs, entropies, self_kls |
|
|