"""Trust region optimization. |

Much of this is adapted from others' code; see Schulman's Modular RL,
wojzaremba's TRPO, etc.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from six.moves import xrange
import tensorflow as tf
import numpy as np


def var_size(v):
  """Returns the total number of elements in the variable `v`."""
  return int(np.prod([int(d) for d in v.shape]))


def gradients(loss, var_list):
  """Like tf.gradients, but substitutes zeros for missing gradients."""
  grads = tf.gradients(loss, var_list)
  return [g if g is not None else tf.zeros(v.shape)
          for g, v in zip(grads, var_list)]


def flatgrad(loss, var_list):
  """Computes the gradient of `loss` w.r.t. `var_list` as one flat vector."""
  grads = gradients(loss, var_list)
  return tf.concat([tf.reshape(grad, [-1]) for grad in grads], 0)


def get_flat(var_list):
  """Flattens and concatenates the variables in `var_list` into one vector."""
  return tf.concat([tf.reshape(v, [-1]) for v in var_list], 0)


def set_from_flat(var_list, flat_theta):
  """Returns an op assigning `var_list` from the flat vector `flat_theta`."""
  shapes = [v.shape for v in var_list]
  sizes = [var_size(v) for v in var_list]

  start = 0
  assigns = []
  for (shape, size, v) in zip(shapes, sizes, var_list):
    assigns.append(v.assign(
        tf.reshape(flat_theta[start:start + size], shape)))
    start += size
  assert start == sum(sizes)

  return tf.group(*assigns)


class TrustRegionOptimization(object):
  """Trust region optimizer using conjugate gradient and a line search."""

  def __init__(self, max_divergence=0.1, cg_damping=0.1):
    self.max_divergence = max_divergence
    self.cg_damping = cg_damping

  def setup_placeholders(self):
    self.flat_tangent = tf.placeholder(tf.float32, [None], 'flat_tangent')
    self.flat_theta = tf.placeholder(tf.float32, [None], 'flat_theta')

  def setup(self, var_list, raw_loss, self_divergence, divergence=None):
    self.setup_placeholders()

    self.raw_loss = raw_loss
    self.divergence = divergence
    self.loss_flat_gradient = flatgrad(raw_loss, var_list)
    self.divergence_gradient = gradients(self_divergence, var_list)

    # Split the flat tangent placeholder into per-variable tensors.
    shapes = [var.shape for var in var_list]
    sizes = [var_size(var) for var in var_list]

    start = 0
    tangents = []
    for shape, size in zip(shapes, sizes):
      param = tf.reshape(self.flat_tangent[start:start + size], shape)
      tangents.append(param)
      start += size
    assert start == sum(sizes)

    # Fisher-vector product: the gradient of (divergence gradient) . tangent.
    self.grad_vector_product = sum(
        tf.reduce_sum(g * t)
        for (g, t) in zip(self.divergence_gradient, tangents))
    self.fisher_vector_product = flatgrad(self.grad_vector_product, var_list)

    self.flat_vars = get_flat(var_list)
    self.set_vars = set_from_flat(var_list, self.flat_theta)

  def optimize(self, sess, feed_dict):
    old_theta = sess.run(self.flat_vars)
    loss_flat_grad = sess.run(self.loss_flat_gradient,
                              feed_dict=feed_dict)

    def calc_fisher_vector_product(tangent):
      feed_dict[self.flat_tangent] = tangent
      fvp = sess.run(self.fisher_vector_product,
                     feed_dict=feed_dict)
      fvp += self.cg_damping * tangent
      return fvp

    # Approximately solve F * step_dir = -g, where F is the damped Fisher.
    step_dir = conjugate_gradient(calc_fisher_vector_product, -loss_flat_grad)

    # Scale the step so that the quadratic estimate
    # 0.5 * step^T F step equals max_divergence.
    shs = 0.5 * step_dir.dot(calc_fisher_vector_product(step_dir))
    lm = np.sqrt(shs / self.max_divergence)
    fullstep = step_dir / lm
    neggdotstepdir = -loss_flat_grad.dot(step_dir)

    def calc_loss(theta):
      sess.run(self.set_vars, feed_dict={self.flat_theta: theta})
      if self.divergence is None:
        return sess.run(self.raw_loss, feed_dict=feed_dict), True
      else:
        raw_loss, divergence = sess.run(
            [self.raw_loss, self.divergence], feed_dict=feed_dict)
        return raw_loss, divergence < self.max_divergence

    theta = linesearch(calc_loss, old_theta, fullstep, neggdotstepdir / lm)

    if self.divergence is not None:
      final_divergence = sess.run(self.divergence, feed_dict=feed_dict)
    else:
      final_divergence = None

    # Only keep the new parameters if the divergence constraint is satisfied.
    if final_divergence is None or final_divergence < self.max_divergence:
      sess.run(self.set_vars, feed_dict={self.flat_theta: theta})
    else:
      sess.run(self.set_vars, feed_dict={self.flat_theta: old_theta})
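

# A hypothetical sketch (not part of the original module) of the
# `self_divergence` argument expected by setup(): the KL divergence of the
# current policy from a stop-gradient copy of itself.  Its value and first
# gradient are numerically zero, but differentiating its gradient again, as
# setup() does, yields Fisher-vector products.  The categorical `logits`
# tensor is an assumed input, not something this module defines.
def _example_self_kl(logits):
  probs = tf.nn.softmax(logits)
  fixed_probs = tf.stop_gradient(probs)
  return tf.reduce_mean(tf.reduce_sum(
      fixed_probs * (tf.log(fixed_probs) - tf.log(probs)), -1))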
|
|
|
|
|
def conjugate_gradient(f_Ax, b, cg_iters=10, residual_tol=1e-10):
  """Solves A x = b for x, given only the matrix-vector product f_Ax."""
  p = b.copy()
  r = b.copy()
  x = np.zeros_like(b)
  rdotr = r.dot(r)
  for i in xrange(cg_iters):
    z = f_Ax(p)
    v = rdotr / p.dot(z)
    x += v * p
    r -= v * z
    newrdotr = r.dot(r)
    mu = newrdotr / rdotr
    p = r + mu * p
    rdotr = newrdotr
    if rdotr < residual_tol:
      break
  return x
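

# A small, hypothetical sanity check (not part of the original module): with
# an explicit symmetric positive-definite matrix, the conjugate gradient
# routine above should agree with a dense solve.
def _check_conjugate_gradient():
  rng = np.random.RandomState(0)
  m = rng.randn(8, 8)
  a = m.dot(m.T) + 8 * np.eye(8)  # symmetric positive-definite
  b = rng.randn(8)
  x = conjugate_gradient(lambda p: a.dot(p), b, cg_iters=20)
  assert np.allclose(x, np.linalg.solve(a, b), atol=1e-5)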
|
|
|
|
|
def linesearch(f, x, fullstep, expected_improve_rate):
  """Backtracking line search; f returns (value, is_valid) for a candidate."""
  accept_ratio = 0.1
  max_backtracks = 10

  fval, _ = f(x)
  for (_n_backtracks, stepfrac) in enumerate(.5 ** np.arange(max_backtracks)):
    xnew = x + stepfrac * fullstep
    newfval, valid = f(xnew)
    if not valid:
      continue
    actual_improve = fval - newfval
    expected_improve = expected_improve_rate * stepfrac
    ratio = actual_improve / expected_improve
    if ratio > accept_ratio and actual_improve > 0:
      return xnew

  return x
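

# A minimal usage sketch (not part of the original module), assuming TF1-style
# graph execution.  The quadratic objective and Euclidean "divergences" below
# are hypothetical stand-ins for a policy loss and KL terms; they only show
# how setup() and optimize() are wired together.
def _example_usage():
  theta = tf.Variable(tf.ones([4]), name='theta')
  prev_theta = tf.placeholder(tf.float32, [4], 'prev_theta')

  raw_loss = tf.reduce_sum(tf.square(theta))
  # Divergence of the parameters from a stop-gradient copy of themselves; its
  # second derivative (here the identity) plays the role of the Fisher matrix.
  self_divergence = 0.5 * tf.reduce_sum(
      tf.square(theta - tf.stop_gradient(theta)))
  # Divergence from the pre-update parameters, checked against max_divergence.
  divergence = 0.5 * tf.reduce_sum(tf.square(theta - prev_theta))

  opt = TrustRegionOptimization(max_divergence=0.1)
  opt.setup([theta], raw_loss, self_divergence, divergence)

  with tf.Session() as sess:
    sess.run(tf.global_variables_initializer())
    feed_dict = {prev_theta: np.ones(4, dtype=np.float32)}
    opt.optimize(sess, feed_dict)
    print('loss after one trust region step:',
          sess.run(raw_loss, feed_dict=feed_dict))


if __name__ == '__main__':
  _example_usage()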
|
|