|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
"""A base class definition for trainable optimizers.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import collections |
|
import itertools |
|
|
|
import tensorflow as tf |
|
|
|
from tensorflow.python.framework import tensor_shape |
|
|
|
# Variable scope under which all optimizer-owned variables are created.
# NOTE(review): the literal value looks like a placeholder, but renaming it
# would invalidate existing checkpoints/scope names — confirm before changing.
OPTIMIZER_SCOPE = "LOL"

# Name prefix for local variables that carry optimizer/problem state across
# partial unrolls during meta-training (see local_state_variables below).
_LOCAL_VARIABLE_PREFIX = "local_state_"

# Graph collection holding the per-name use counters that keep those local
# state variable names unique.
_LOCAL_STATE_VARIABLE_COLLECTION = "local_state_collection"

# Small constant for numerical stability in objective scaling.
EPSILON = 1e-6
|
|
|
|
|
class TrainableOptimizer(tf.train.Optimizer):
  """Base class for trainable optimizers.

  A trainable optimizer is an optimizer that has parameters that can themselves
  be learned (meta-optimized).

  Subclasses must implement:
      _compute_update(self, param, grad, state)
  """

  def __init__(self, name, state_keys, use_attention=False,
               use_log_objective=False, obj_train_max_multiplier=-1,
               use_second_derivatives=True, use_numerator_epsilon=False,
               **kwargs):
    """Initializes the optimizer with the given name and settings.

    Args:
      name: The name string for this optimizer.
      state_keys: The names of any required state variables (list)
      use_attention: Whether this optimizer uses attention (Default: False)
      use_log_objective: Whether this optimizer uses the logarithm of the
        objective when computing the loss (Default: False)
      obj_train_max_multiplier: The maximum multiplier for the increase in the
        objective before meta-training is stopped. If <= 0, meta-training is
        not stopped early. (Default: -1)
      use_second_derivatives: Whether this optimizer uses second derivatives in
        meta-training. This should be set to False if some second derivatives
        in the meta-training problem set are not defined in Tensorflow.
        (Default: True)
      use_numerator_epsilon: Whether to use epsilon in the numerator when
        scaling the problem objective during meta-training. (Default: False)
      **kwargs: Any additional keyword arguments.
    """
    self.use_second_derivatives = use_second_derivatives
    # Keys are sorted so flattened state lists have a canonical order that
    # matches flatten_and_sort (which sorts dictionary keys).
    self.state_keys = sorted(state_keys)
    self.use_attention = use_attention
    self.use_log_objective = use_log_objective
    self.obj_train_max_multiplier = obj_train_max_multiplier
    self.use_numerator_epsilon = use_numerator_epsilon

    use_locking = False
    super(TrainableOptimizer, self).__init__(use_locking, name)

  def _create_slots(self, var_list):
    """Creates all slots needed by the variables.

    Args:
      var_list: A list of `Variable` objects.
    """
    for var in var_list:
      init_states = self._initialize_state(var)
      # Iterate in sorted order so slot creation order is deterministic.
      for slot_name in sorted(init_states):
        slot_var_name = "{}_{}".format(self.get_name(), slot_name)
        value = init_states[slot_name]
        self._get_or_make_slot(var, value, slot_name, slot_var_name)

  def _initialize_state(self, var):
    """Initializes any state required for this variable.

    Args:
      var: a tensor containing parameters to be optimized

    Returns:
      state: a dictionary mapping state keys to initial state values (tensors)
    """
    return {}

  def _initialize_global_state(self):
    """Initializes any global state values."""
    return []

  def _apply_common(self, grad, var):
    """Applies the optimizer updates to the variables.

    Note: this should only get called via _apply_dense or _apply_sparse when
    using the optimizer via optimizer.minimize or optimizer.apply_gradients.
    During meta-training, the optimizer.train function should be used to
    construct an optimization path that is differentiable.

    Args:
      grad: A tensor representing the gradient.
      var: A tf.Variable with the same shape as grad.

    Returns:
      update_op: A tensorflow op that assigns new values to the variable, and
        also defines dependencies that update the state variables for the
        optimizer.
    """
    state = {key: self.get_slot(var, key) for key in self.get_slot_names()}
    new_var, new_state = self._compute_update(var, grad, state)
    state_assign_ops = [tf.assign(state_var, new_state[key])
                        for key, state_var in state.items()]
    # Update the state before the parameters so a single run of update_op
    # advances both consistently.
    with tf.control_dependencies(state_assign_ops):
      update_op = var.assign(new_var)

    return update_op

  def _apply_dense(self, grad, var):
    """Adds ops to apply dense gradients to 'var'."""
    return self._apply_common(grad, var)

  def _apply_sparse(self, grad, var):
    """Adds ops to apply sparse gradients to 'var'."""
    return self._apply_common(grad, var)

  def _compute_update(self, param, grad, state):
    """Computes the update step for optimization.

    Args:
      param: A tensor of parameters to optimize.
      grad: The gradient tensor of the objective with respect to the parameters.
        (It has the same shape as param.)
      state: A dictionary containing any extra state required by the optimizer.

    Returns:
      updated_params: The updated parameters.
      updated_state: The dictionary of updated state variable(s).
    """
    raise NotImplementedError

  def _compute_updates(self, params, grads, states, global_state):
    """Maps the compute update functions for each parameter.

    This function can be overriden by a subclass if the subclass wants to
    combine information across the different parameters in the list.

    Args:
      params: A list of parameter tensors.
      grads: A list of gradients corresponding to each parameter.
      states: A list of state variables corresponding to each parameter.
      global_state: A list of global state variables for the problem.

    Returns:
      new_params: The updated parameters.
      new_states: The updated states.
      new_global_state: The updated global state.
      attention_params: A list of attention parameters. This is the same as
        new_params if the optimizer does not use attention.
    """
    # Apply the per-parameter update rule independently to each parameter.
    args = zip(params, grads, states)
    new_params, new_states = zip(*list(
        itertools.starmap(self._compute_update, args)))

    # The base implementation has no attention, so the attention parameters
    # are just the updated parameters; global state is passed through.
    return list(new_params), list(new_states), global_state, list(new_params)

  def train(self, problem, dataset):
    """Creates graph operations to train the optimizer.

    Args:
      problem: A problem_generator.Problem instance to train on.
      dataset: A datasets.Dataset tuple to use when training.

    Returns:
      meta_objective: A tensorflow operation for computing the meta-objective
      obj_weights: A tensor placeholder for feeding in the objective weights
      obj_values: The subproblem objective values during optimization
      batches: The batch indexes tensor for overriding with feed_dict
      first_unroll: A placeholder signifying if this is a first unroll
        (this will propagate the gradients slightly differently).
      reset_state: A placeholder signifying that the rnn state should be reset.
      output_state: The final state of the optimizer
      init_loop_vars_to_override: Local variables that can be assigned to
        propagate the optimizer and problem state for unrolling
      final_loop_vals: Final values of the loop variables that can be
        assigned to init_loop_vars_to_override.
    """
    # The number of unroll iterations is inferred from the number of
    # objective weights fed in at run time.
    obj_weights = tf.placeholder(tf.float32)
    num_iter = tf.shape(obj_weights)[0]

    data, labels = dataset
    data = tf.constant(data)
    labels = tf.constant(labels)
    batches = tf.placeholder(tf.int32)
    first_unroll = tf.placeholder_with_default(False, [])
    reset_state = tf.placeholder_with_default(False, [])

    training_output = collections.namedtuple("TrainingOutput",
                                             ["metaobj",
                                              "obj_weights",
                                              "problem_objectives",
                                              "initial_obj",
                                              "batches",
                                              "first_unroll",
                                              "reset_state",
                                              "output_state",
                                              "init_loop_vars",
                                              "output_loop_vars"])

    def loop_body(itr, obj_accum, params, attend_params, flattened_states,
                  global_state, all_obj, unused_init_obj, data,
                  labels, batches):
      """Body of the meta-training while loop for optimizing a sub-problem.

      Args:
        itr: The current meta-training iteration.
        obj_accum: The accumulated objective over all training steps so far.
        params: The parameters of the sub-problem.
        attend_params: The parameters of the sub-problems at the attended
          location.
        flattened_states: The states of the trainable optimizer, sorted and
          flattened into a list (since a while loop can't handle nested lists
          or dictionaries).
        global_state: The global state of the optimizer.
        all_obj: The list of all objective values in the training process.
        unused_init_obj: The initial objective (unused here, but needed in the
          variable list because it's used in a stopping condition in the
          loop_cond.)
        data: The data for this problem.
        labels: The labels corresponding to the data.
        batches: The batch indexes needed for shuffled minibatch creation.

      Returns:
        itr: The updated meta-training iteration.
        obj_accum: The updated accumulated objective.
        params: The new parameters of the sub-problem.
        attend_params: The new parameters of the sub-problems at the attended
          location.
        flattened_states: The new states of the trainable optimizer.
        global_state: The updated global state.
        all_obj: The updates list of all objective values.
        unused_init_obj: The initial objective.
        data: The data for this problem.
        labels: The labels corresponding to the data.
        batches: The batch indexes needed for shuffled minibatch creation.
      """
      batch_indices = tf.gather(batches, itr)
      batch_data = tf.gather(data, batch_indices)
      batch_labels = tf.gather(labels, batch_indices)

      # The meta-objective is measured on the full dataset; only the
      # gradients fed to the learned optimizer use the minibatch.
      obj = problem.objective(params, data, labels)

      if self.use_attention:
        current_obj = problem.objective(attend_params, batch_data, batch_labels)
        grads = problem.gradients(current_obj, attend_params)
      else:
        current_obj = problem.objective(params, batch_data, batch_labels)
        grads = problem.gradients(current_obj, params)

      if not self.use_second_derivatives:
        # Stop gradients so meta-training does not backprop through the
        # sub-problem's gradient computation (i.e. no second derivatives).
        new_grads = []
        for grad in grads:
          if isinstance(grad, tf.IndexedSlices):
            new_grads.append(
                tf.IndexedSlices(tf.stop_gradient(grad.values), grad.indices))
          else:
            new_grads.append(tf.stop_gradient(grad))
        grads = new_grads

      all_obj = tf.concat([all_obj, tf.reshape(obj, (1,))], 0)

      # Weight this step's objective and fold it into the accumulator.
      acc = tf.gather(obj_weights, itr) * obj
      obj_accum = tf.add(obj_accum, acc)
      # tf.gather loses the static scalar shape; restore it so the loop
      # invariants continue to hold.
      obj_accum.set_shape([])

      # Reconstruct the per-parameter state dictionaries from the flattened
      # lists (state_keys is sorted, matching flatten_and_sort's order).
      dict_states = [dict(zip(self.state_keys, flat_state))
                     for flat_state in flattened_states]

      args = (params, grads, dict_states, global_state)
      updates = self._compute_updates(*args)
      new_params, new_states, new_global_state, new_attend_params = updates

      # BUG FIX: map() is lazy in Python 3; tf.while_loop requires a concrete
      # list matching the loop-variable structure, so materialize it.
      new_flattened_states = list(map(flatten_and_sort, new_states))

      return [itr + 1, obj_accum, new_params, new_attend_params,
              new_flattened_states, new_global_state, all_obj, unused_init_obj,
              data, labels, batches]

    def loop_cond(itr, obj_accum, unused_params, unused_attend_params,
                  unused_flattened_states, unused_global_state, all_obj,
                  init_obj, *args):
      """Termination conditions of the sub-problem optimization loop."""
      del args  # unused

      cond1 = tf.less(itr, num_iter)
      cond2 = tf.is_finite(obj_accum)

      if self.obj_train_max_multiplier > 0:
        # Stop early if the objective has blown up past the allowed multiple
        # of the initial objective.
        current_obj = tf.gather(all_obj, itr)
        max_diff = (self.obj_train_max_multiplier - 1) * tf.abs(init_obj)
        max_obj = init_obj + max_diff
        cond3 = tf.less(current_obj, max_obj)
        return tf.logical_and(tf.logical_and(cond1, cond2), cond3,
                              name="training_loop_cond")
      else:
        return tf.logical_and(cond1, cond2, name="training_loop_cond")

    init = self._initialize_training_loop_parameters(
        problem, data, labels, batches, first_unroll, reset_state)
    loop_vars, invariants, initial_obj, init_loop_vars_to_override = init

    loop_output = tf.while_loop(loop_cond, loop_body, loop_vars,
                                swap_memory=True, shape_invariants=invariants)
    meta_obj, problem_objectives = loop_output[1], loop_output[6]

    scaled_meta_objective = self.scale_objective(
        meta_obj, problem_objectives, initial_obj)

    # Collect the final loop values in the same order as
    # init_loop_vars_to_override so they can be assigned across unrolls.
    final_loop_vals = (
        [initial_obj] + loop_output[2] + loop_output[3] + loop_output[5])
    final_loop_vals.extend(itertools.chain(*loop_output[4]))

    return training_output(scaled_meta_objective,
                           obj_weights,
                           problem_objectives,
                           initial_obj,
                           batches,
                           first_unroll,
                           reset_state,
                           loop_output[4],
                           init_loop_vars_to_override,
                           final_loop_vals)

  def _initialize_training_loop_parameters(
      self, problem, data, labels, batches, first_unroll, reset_state):
    """Initializes the vars and params needed for the training process.

    Args:
      problem: The problem being optimized.
      data: The data for the problem.
      labels: The corresponding labels for the data.
      batches: The indexes needed to create shuffled batches of the data.
      first_unroll: Whether this is the first unroll in a partial unrolling.
      reset_state: Whether RNN state variables should be reset.

    Returns:
      loop_vars: The while loop variables for training.
      invariants: The corresponding variable shapes (required by while loop).
      initial_obj: The initial objective (used later for scaling).
      init_loop_vars_to_override: The loop vars that can be overridden when
        performing training via partial unrolls.
    """
    # Parameters (and attended parameters) live in local variables so their
    # values can be carried across partial unrolls; on the first unroll the
    # fresh initial tensors are used instead.
    initial_tensors = problem.init_tensors()

    return_initial_tensor_values = first_unroll
    initial_params_vars, initial_params = local_state_variables(
        initial_tensors, return_initial_tensor_values)
    initial_attend_params_vars, initial_attend_params = local_state_variables(
        initial_tensors, return_initial_tensor_values)

    # The initial objective is also stored in a local variable so that the
    # same scaling reference persists across unrolls.
    initial_obj_init = problem.objective(initial_params, data, labels)
    return_initial_obj_init = first_unroll
    [initial_obj_var], [initial_obj] = local_state_variables(
        [initial_obj_init], return_initial_obj_init)

    initial_itr = tf.constant(0, dtype=tf.int32)
    initial_meta_obj = tf.constant(0, dtype=tf.float32)

    initial_problem_objectives = tf.reshape(initial_obj_init, (1,))

    # Per-parameter optimizer state, flattened into sorted lists (while loops
    # cannot carry nested dictionaries).
    initial_state_vars = []
    initial_state = []
    state_shapes = []
    return_initial_state_values = reset_state
    for param in initial_tensors:
      param_state_vars, param_state = local_state_variables(
          flatten_and_sort(self._initialize_state(param)),
          return_initial_state_values)

      initial_state_vars.append(param_state_vars)
      initial_state.append(param_state)
      state_shapes.append([f.get_shape() for f in param_state])

    initial_global_state_vars, initial_global_state = local_state_variables(
        self._initialize_global_state(), return_initial_state_values)

    global_shapes = []
    for item in initial_global_state:
      global_shapes.append(item.get_shape())

    # Order here must match loop_body's argument/return order.
    loop_vars = [
        initial_itr,
        initial_meta_obj,
        initial_params,
        initial_attend_params,
        initial_state,
        initial_global_state,
        initial_problem_objectives,
        initial_obj,
        data,
        labels,
        batches,
    ]

    invariants = [
        initial_itr.get_shape(),
        initial_meta_obj.get_shape(),
        [t.get_shape() for t in initial_params],
        [t.get_shape() for t in initial_attend_params],
        state_shapes,
        global_shapes,
        # The list of all objectives grows by one element per iteration.
        tensor_shape.TensorShape([None]),
        initial_obj.get_shape(),
        tensor_shape.unknown_shape(),
        tensor_shape.unknown_shape(),
        tensor_shape.unknown_shape(),
    ]

    # Order here must match the order of final_loop_vals built in train().
    init_loop_vars_to_override = (
        [initial_obj_var] + initial_params_vars + initial_attend_params_vars +
        initial_global_state_vars)
    init_loop_vars_to_override.extend(itertools.chain(*initial_state_vars))

    return loop_vars, invariants, initial_obj, init_loop_vars_to_override

  def scale_objective(self, total_obj, all_objs, initial_obj,
                      obj_scale_eps=1e-6):
    """Normalizes the objective based on the initial objective value.

    Args:
      total_obj: The total accumulated objective over the training run.
      all_objs: A list of all the individual objectives over the training run.
      initial_obj: The initial objective value.
      obj_scale_eps: The epsilon value to use in computations for stability.

    Returns:
      The scaled objective as a single value.
    """
    if self.use_log_objective:
      if self.use_numerator_epsilon:
        scaled_problem_obj = ((all_objs + obj_scale_eps) /
                              (initial_obj + obj_scale_eps))
        log_scaled_problem_obj = tf.log(scaled_problem_obj)
      else:
        scaled_problem_obj = all_objs / (initial_obj + obj_scale_eps)
        log_scaled_problem_obj = tf.log(scaled_problem_obj + obj_scale_eps)
      return tf.reduce_mean(log_scaled_problem_obj)
    else:
      return total_obj / (initial_obj + obj_scale_eps)
|
|
|
|
|
def local_state_variables(init_values, return_init_values):
  """Create local variables initialized from init_values.

  This will create local variables from a list of init_values. Each variable
  will be named based on the value's shape and dtype.

  As a convenience, a boolean tensor allows you to return value from
  the created local variable or from the original init value.

  Args:
    init_values: iterable of tensors
    return_init_values: boolean tensor

  Returns:
    local_vars: list of the created local variables.
    vals: if return_init_values is true, then this returns the values of
      init_values. Otherwise it returns the values of the local_vars.
  """
  if not init_values:
    return [], []

  # The collection stores a single defaultdict mapping each base variable
  # name to the number of times it has been used, so repeated calls with
  # same-shaped tensors still get unique variable names.
  variable_use_count = tf.get_collection_ref(_LOCAL_STATE_VARIABLE_COLLECTION)
  if not variable_use_count:
    variable_use_count.append(collections.defaultdict(int))
  variable_use_count = variable_use_count[0]

  local_vars = []
  with tf.variable_scope(OPTIMIZER_SCOPE):
    for init_value in init_values:
      # Name encodes shape and dtype; a use counter suffix keeps it unique.
      name = create_local_state_variable_name(init_value)
      unique_name = name + "_" + str(variable_use_count[name])
      variable_use_count[name] += 1
      # Initialized to zeros; the meaningful value is either assigned
      # externally or bypassed via the cond below on the first use.
      local_vars.append(
          tf.get_local_variable(
              unique_name,
              initializer=tf.zeros(
                  init_value.get_shape(), dtype=init_value.dtype)))

  # Select between the fresh init values and the persisted local variables.
  vals = tf.cond(return_init_values, lambda: init_values, lambda: local_vars)
  if len(init_values) == 1:
    # NOTE(review): tf.cond appears to unwrap a single-element list to a bare
    # tensor here, so re-wrap to keep the list contract — relies on TF1 cond
    # behavior; confirm if upgrading TF.
    vals = [vals]
  return local_vars, vals
|
|
|
|
|
def create_local_state_variable_name(tensor):
  """Create a name of the variable based on its type and shape."""
  shape = tensor.get_shape()
  if not shape.is_fully_defined():
    raise ValueError("Need a fully specified shape to create a local variable.")

  # e.g. a float32 tensor of shape (3, 4) -> "local_state_3_4_float32".
  dims = "_".join(str(dim) for dim in shape.as_list())
  return _LOCAL_VARIABLE_PREFIX + dims + "_" + tensor.dtype.name
|
|
|
|
|
def is_local_state_variable(op):
  """Returns if this op is a local state variable created for training."""
  # Must be a variable op whose name lives under the optimizer scope with
  # the local-state prefix.
  is_variable_op = op.node_def.op in ("Variable", "VariableV2")
  expected_prefix = OPTIMIZER_SCOPE + "/" + _LOCAL_VARIABLE_PREFIX
  return is_variable_op and op.name.startswith(expected_prefix)
|
|
|
|
|
def flatten_and_sort(dictionary): |
|
"""Flattens a dictionary into a list of values sorted by the keys.""" |
|
return [dictionary[k] for k in sorted(dictionary.keys())] |
|
|