# Copyright 2017 Google, Inc. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. # ============================================================================== """A base class definition for trainable optimizers.""" from __future__ import absolute_import from __future__ import division from __future__ import print_function import collections import itertools import tensorflow as tf from tensorflow.python.framework import tensor_shape OPTIMIZER_SCOPE = "LOL" _LOCAL_VARIABLE_PREFIX = "local_state_" _LOCAL_STATE_VARIABLE_COLLECTION = "local_state_collection" EPSILON = 1e-6 class TrainableOptimizer(tf.train.Optimizer): """Base class for trainable optimizers. A trainable optimizer is an optimizer that has parameters that can themselves be learned (meta-optimized). Subclasses must implement: _compute_update(self, param, grad, state) """ def __init__(self, name, state_keys, use_attention=False, use_log_objective=False, obj_train_max_multiplier=-1, use_second_derivatives=True, use_numerator_epsilon=False, **kwargs): """Initializes the optimizer with the given name and settings. Args: name: The name string for this optimizer. state_keys: The names of any required state variables (list) use_attention: Whether this optimizer uses attention (Default: True) use_log_objective: Whether this optimizer uses the logarithm of the objective when computing the loss (Default: False) obj_train_max_multiplier: The maximum multiplier for the increase in the objective before meta-training is stopped. If <= 0, meta-training is not stopped early. (Default: -1) use_second_derivatives: Whether this optimizer uses second derivatives in meta-training. This should be set to False if some second derivatives in the meta-training problem set are not defined in Tensorflow. (Default: True) use_numerator_epsilon: Whether to use epsilon in the numerator when scaling the problem objective during meta-training. (Default: False) **kwargs: Any additional keyword arguments. """ self.use_second_derivatives = use_second_derivatives self.state_keys = sorted(state_keys) self.use_attention = use_attention self.use_log_objective = use_log_objective self.obj_train_max_multiplier = obj_train_max_multiplier self.use_numerator_epsilon = use_numerator_epsilon use_locking = False super(TrainableOptimizer, self).__init__(use_locking, name) def _create_slots(self, var_list): """Creates all slots needed by the variables. Args: var_list: A list of `Variable` objects. """ for var in var_list: init_states = self._initialize_state(var) for slot_name in sorted(init_states): slot_var_name = "{}_{}".format(self.get_name(), slot_name) value = init_states[slot_name] self._get_or_make_slot(var, value, slot_name, slot_var_name) def _initialize_state(self, var): """Initializes any state required for this variable. Args: var: a tensor containing parameters to be optimized Returns: state: a dictionary mapping state keys to initial state values (tensors) """ return {} def _initialize_global_state(self): """Initializes any global state values.""" return [] def _apply_common(self, grad, var): """Applies the optimizer updates to the variables. Note: this should only get called via _apply_dense or _apply_sparse when using the optimizer via optimizer.minimize or optimizer.apply_gradients. During meta-training, the optimizer.train function should be used to construct an optimization path that is differentiable. Args: grad: A tensor representing the gradient. var: A tf.Variable with the same shape as grad. Returns: update_op: A tensorflow op that assigns new values to the variable, and also defines dependencies that update the state variables for the optimizer. """ state = {key: self.get_slot(var, key) for key in self.get_slot_names()} new_var, new_state = self._compute_update(var, grad, state) state_assign_ops = [tf.assign(state_var, new_state[key]) for key, state_var in state.items()] with tf.control_dependencies(state_assign_ops): update_op = var.assign(new_var) return update_op def _apply_dense(self, grad, var): """Adds ops to apply dense gradients to 'var'.""" return self._apply_common(grad, var) def _apply_sparse(self, grad, var): """Adds ops to apply sparse gradients to 'var'.""" return self._apply_common(grad, var) def _compute_update(self, param, grad, state): """Computes the update step for optimization. Args: param: A tensor of parameters to optimize. grad: The gradient tensor of the objective with respect to the parameters. (It has the same shape as param.) state: A dictionary containing any extra state required by the optimizer. Returns: updated_params: The updated parameters. updated_state: The dictionary of updated state variable(s). """ raise NotImplementedError def _compute_updates(self, params, grads, states, global_state): """Maps the compute update functions for each parameter. This function can be overriden by a subclass if the subclass wants to combine information across the different parameters in the list. Args: params: A list of parameter tensors. grads: A list of gradients corresponding to each parameter. states: A list of state variables corresponding to each parameter. global_state: A list of global state variables for the problem. Returns: new_params: The updated parameters. new_states: The updated states. new_global_state: The updated global state. attention_params: A list of attention parameters. This is the same as new_params if the optimizer does not use attention. """ # Zip up the arguments to _compute_update. args = zip(params, grads, states) # Call compute_update on each set of parameter/gradient/state args. new_params, new_states = zip(*list( itertools.starmap(self._compute_update, args))) # Global state is unused in the basic case, just pass it through. return list(new_params), list(new_states), global_state, list(new_params) def train(self, problem, dataset): """Creates graph operations to train the optimizer. Args: problem: A problem_generator.Problem instance to train on. dataset: A datasets.Dataset tuple to use when training. Returns: meta_objective: A tensorflow operation for computing the meta-objective obj_weights: A tensor placeholder for feeding in the objective weights obj_values: The subproblem objective values during optimization batches: The batch indexes tensor for overriding with feed_dict first_unroll: A placeholder signifying if this is a first unroll (this will propagate the gradients slightly differently). reset_state: A placeholder signifying that the rnn state should be reset. output_state: The final state of the optimizer init_loop_vars_to_override: Local variables that can be assigned to propagate the optimizer and problem state for unrolling final_loop_vals: Final values of the loop variables that can be assigned to init_loop_vars_to_override. """ # Placeholder for the objective weights obj_weights = tf.placeholder(tf.float32) num_iter = tf.shape(obj_weights)[0] # Unpack the dataset and generate the minibatches for training data, labels = dataset # Convert the ndarrays to tensors so we can pass them back in via feed_dict data = tf.constant(data) labels = tf.constant(labels) batches = tf.placeholder(tf.int32) first_unroll = tf.placeholder_with_default(False, []) reset_state = tf.placeholder_with_default(False, []) training_output = collections.namedtuple("TrainingOutput", ["metaobj", "obj_weights", "problem_objectives", "initial_obj", "batches", "first_unroll", "reset_state", "output_state", "init_loop_vars", "output_loop_vars"]) def loop_body(itr, obj_accum, params, attend_params, flattened_states, global_state, all_obj, unused_init_obj, data, labels, batches): """Body of the meta-training while loop for optimizing a sub-problem. Args: itr: The current meta-training iteration. obj_accum: The accumulated objective over all training steps so far. params: The parameters of the sub-problem. attend_params: The parameters of the sub-problems at the attended location. flattened_states: The states of the trainable optimizer, sorted and flattened into a list (since a while loop can't handle nested lists or dictionaries). global_state: The global state of the optimizer. all_obj: The list of all objective values in the training process. unused_init_obj: The initial objective (unused here, but needed in the variable list because it's used in a stopping condition in the loop_cond.) data: The data for this problem. labels: The labels corresponding to the data. batches: The batch indexes needed for shuffled minibatch creation. Returns: itr: The updated meta-training iteration. obj_accum: The updated accumulated objective. params: The new parameters of the sub-problem. attend_params: The new parameters of the sub-problems at the attended location. flattened_states: The new states of the trainable optimizer. global_state: The updated global state. all_obj: The updates list of all objective values. unused_init_obj: The initial objective. data: The data for this problem. labels: The labels corresponding to the data. batches: The batch indexes needed for shuffled minibatch creation. """ batch_indices = tf.gather(batches, itr) batch_data = tf.gather(data, batch_indices) batch_labels = tf.gather(labels, batch_indices) # Compute the objective over the entire dataset (full batch). obj = problem.objective(params, data, labels) # Compute the gradients on just the current batch if self.use_attention: current_obj = problem.objective(attend_params, batch_data, batch_labels) grads = problem.gradients(current_obj, attend_params) else: current_obj = problem.objective(params, batch_data, batch_labels) grads = problem.gradients(current_obj, params) if not self.use_second_derivatives: new_grads = [] for grad in grads: if isinstance(grad, tf.IndexedSlices): new_grads.append( tf.IndexedSlices(tf.stop_gradient(grad.values), grad.indices)) else: new_grads.append(tf.stop_gradient(grad)) grads = new_grads # store the objective value for the entire problem at each iteration all_obj = tf.concat([all_obj, tf.reshape(obj, (1,))], 0) # accumulate the weighted objective for the entire dataset acc = tf.gather(obj_weights, itr) * obj obj_accum = tf.add(obj_accum, acc) # Set the shape to keep the shape invariant for obj_accum. Without this, # the graph builder thinks the tensor shape is unknown on the 2nd iter. obj_accum.set_shape([]) # convert flattened_states to dictionaries dict_states = [dict(zip(self.state_keys, flat_state)) for flat_state in flattened_states] # compute the new parameters and states args = (params, grads, dict_states, global_state) updates = self._compute_updates(*args) new_params, new_states, new_global_state, new_attend_params = updates # flatten the states new_flattened_states = map(flatten_and_sort, new_states) return [itr + 1, obj_accum, new_params, new_attend_params, new_flattened_states, new_global_state, all_obj, unused_init_obj, data, labels, batches] def loop_cond(itr, obj_accum, unused_params, unused_attend_params, unused_flattened_states, unused_global_state, all_obj, init_obj, *args): """Termination conditions of the sub-problem optimization loop.""" del args # unused cond1 = tf.less(itr, num_iter) # We've run < num_iter times cond2 = tf.is_finite(obj_accum) # The objective is still finite if self.obj_train_max_multiplier > 0: current_obj = tf.gather(all_obj, itr) # Account for negative init_obj too max_diff = (self.obj_train_max_multiplier - 1) * tf.abs(init_obj) max_obj = init_obj + max_diff # The objective is a reasonable multiplier of the original objective cond3 = tf.less(current_obj, max_obj) return tf.logical_and(tf.logical_and(cond1, cond2), cond3, name="training_loop_cond") else: return tf.logical_and(cond1, cond2, name="training_loop_cond") init = self._initialize_training_loop_parameters( problem, data, labels, batches, first_unroll, reset_state) loop_vars, invariants, initial_obj, init_loop_vars_to_override = init loop_output = tf.while_loop(loop_cond, loop_body, loop_vars, swap_memory=True, shape_invariants=invariants) meta_obj, problem_objectives = loop_output[1], loop_output[6] # The meta objective is normalized by the initial objective at the start of # the series of partial unrolls. scaled_meta_objective = self.scale_objective( meta_obj, problem_objectives, initial_obj) final_loop_vals = ( [initial_obj] + loop_output[2] + loop_output[3] + loop_output[5]) final_loop_vals.extend(itertools.chain(*loop_output[4])) return training_output(scaled_meta_objective, obj_weights, problem_objectives, initial_obj, batches, first_unroll, reset_state, loop_output[4], init_loop_vars_to_override, final_loop_vals) def _initialize_training_loop_parameters( self, problem, data, labels, batches, first_unroll, reset_state): """Initializes the vars and params needed for the training process. Args: problem: The problem being optimized. data: The data for the problem. labels: The corresponding labels for the data. batches: The indexes needed to create shuffled batches of the data. first_unroll: Whether this is the first unroll in a partial unrolling. reset_state: Whether RNN state variables should be reset. Returns: loop_vars: The while loop variables for training. invariants: The corresponding variable shapes (required by while loop). initial_obj: The initial objective (used later for scaling). init_loop_vars_to_override: The loop vars that can be overridden when performing training via partial unrolls. """ # Extract these separately so we don't have to make inter-variable # dependencies. initial_tensors = problem.init_tensors() return_initial_tensor_values = first_unroll initial_params_vars, initial_params = local_state_variables( initial_tensors, return_initial_tensor_values) initial_attend_params_vars, initial_attend_params = local_state_variables( initial_tensors, return_initial_tensor_values) # Recalculate the initial objective for the list on each partial unroll with # the new initial_params. initial_obj holds the value from the very first # unroll. initial_obj_init = problem.objective(initial_params, data, labels) return_initial_obj_init = first_unroll [initial_obj_var], [initial_obj] = local_state_variables( [initial_obj_init], return_initial_obj_init) # Initialize the loop variables. initial_itr = tf.constant(0, dtype=tf.int32) initial_meta_obj = tf.constant(0, dtype=tf.float32) # N.B. the use of initial_obj_init here rather than initial_obj initial_problem_objectives = tf.reshape(initial_obj_init, (1,)) # Initialize the extra state. initial_state_vars = [] initial_state = [] state_shapes = [] return_initial_state_values = reset_state for param in initial_tensors: param_state_vars, param_state = local_state_variables( flatten_and_sort(self._initialize_state(param)), return_initial_state_values) initial_state_vars.append(param_state_vars) initial_state.append(param_state) state_shapes.append([f.get_shape() for f in param_state]) # Initialize any global (problem-level) state. initial_global_state_vars, initial_global_state = local_state_variables( self._initialize_global_state(), return_initial_state_values) global_shapes = [] for item in initial_global_state: global_shapes.append(item.get_shape()) # build the list of loop variables: loop_vars = [ initial_itr, initial_meta_obj, initial_params, # Local variables. initial_attend_params, # Local variables. initial_state, # Local variables. initial_global_state, # Local variables. initial_problem_objectives, initial_obj, # Local variable. data, labels, batches, ] invariants = [ initial_itr.get_shape(), initial_meta_obj.get_shape(), [t.get_shape() for t in initial_params], [t.get_shape() for t in initial_attend_params], state_shapes, global_shapes, tensor_shape.TensorShape([None]), # The problem objectives list grows initial_obj.get_shape(), tensor_shape.unknown_shape(), # Placeholder shapes are unknown tensor_shape.unknown_shape(), tensor_shape.unknown_shape(), ] # Initialize local variables that we will override with final tensors at the # next iter. init_loop_vars_to_override = ( [initial_obj_var] + initial_params_vars + initial_attend_params_vars + initial_global_state_vars) init_loop_vars_to_override.extend(itertools.chain(*initial_state_vars)) return loop_vars, invariants, initial_obj, init_loop_vars_to_override def scale_objective(self, total_obj, all_objs, initial_obj, obj_scale_eps=1e-6): """Normalizes the objective based on the initial objective value. Args: total_obj: The total accumulated objective over the training run. all_objs: A list of all the individual objectives over the training run. initial_obj: The initial objective value. obj_scale_eps: The epsilon value to use in computations for stability. Returns: The scaled objective as a single value. """ if self.use_log_objective: if self.use_numerator_epsilon: scaled_problem_obj = ((all_objs + obj_scale_eps) / (initial_obj + obj_scale_eps)) log_scaled_problem_obj = tf.log(scaled_problem_obj) else: scaled_problem_obj = all_objs / (initial_obj + obj_scale_eps) log_scaled_problem_obj = tf.log(scaled_problem_obj + obj_scale_eps) return tf.reduce_mean(log_scaled_problem_obj) else: return total_obj / (initial_obj + obj_scale_eps) def local_state_variables(init_values, return_init_values): """Create local variables initialized from init_values. This will create local variables from a list of init_values. Each variable will be named based on the value's shape and dtype. As a convenience, a boolean tensor allows you to return value from the created local variable or from the original init value. Args: init_values: iterable of tensors return_init_values: boolean tensor Returns: local_vars: list of the created local variables. vals: if return_init_values is true, then this returns the values of init_values. Otherwise it returns the values of the local_vars. """ if not init_values: return [], [] # This generates a harmless warning when saving the metagraph. variable_use_count = tf.get_collection_ref(_LOCAL_STATE_VARIABLE_COLLECTION) if not variable_use_count: variable_use_count.append(collections.defaultdict(int)) variable_use_count = variable_use_count[0] local_vars = [] with tf.variable_scope(OPTIMIZER_SCOPE): # We can't use the init_value as an initializer as init_value may # itself depend on some problem variables. This would produce # inter-variable initialization order dependence which TensorFlow # sucks at making easy. for init_value in init_values: name = create_local_state_variable_name(init_value) unique_name = name + "_" + str(variable_use_count[name]) variable_use_count[name] += 1 # The overarching idea here is to be able to reuse variables between # different sessions on the same TensorFlow master without errors. By # uniquifying based on the type and name we mirror the checks made inside # TensorFlow, while still allowing some memory reuse. Ultimately this is a # hack due to the broken Session.reset(). local_vars.append( tf.get_local_variable( unique_name, initializer=tf.zeros( init_value.get_shape(), dtype=init_value.dtype))) # It makes things a lot simpler if we use the init_value the first # iteration, instead of the variable itself. It allows us to propagate # gradients through it as well as simplifying initialization. The variable # ends up assigned to after the first iteration. vals = tf.cond(return_init_values, lambda: init_values, lambda: local_vars) if len(init_values) == 1: # tf.cond extracts elements from singleton lists. vals = [vals] return local_vars, vals def create_local_state_variable_name(tensor): """Create a name of the variable based on its type and shape.""" if not tensor.get_shape().is_fully_defined(): raise ValueError("Need a fully specified shape to create a local variable.") return (_LOCAL_VARIABLE_PREFIX + "_".join( map(str, tensor.get_shape().as_list())) + "_" + tensor.dtype.name) def is_local_state_variable(op): """Returns if this op is a local state variable created for training.""" return op.node_def.op in ["Variable", "VariableV2"] and op.name.startswith( OPTIMIZER_SCOPE + "/" + _LOCAL_VARIABLE_PREFIX) def flatten_and_sort(dictionary): """Flattens a dictionary into a list of values sorted by the keys.""" return [dictionary[k] for k in sorted(dictionary.keys())]