"""Implementation of the ModelAdapter class.""" |
|
|
|
from __future__ import absolute_import |
|
from __future__ import division |
|
from __future__ import print_function |
|
|
|
import mock |
|
import tensorflow as tf |
|
|
|
from learned_optimizer.problems import problem_generator as pg |
|
|
|
|
|


class ModelAdapter(pg.Problem):
  """Adapts TensorFlow models/graphs into a form suitable for meta-training.

  This class adapts an existing TensorFlow graph into a form suitable for
  meta-training a learned optimizer.
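
  Example usage (a sketch; `my_make_loss_and_init_fn` stands for any callable
  that follows the contract described in `__init__`):

    problem = ModelAdapter(my_make_loss_and_init_fn)
    params = problem.init_variables()
    loss = problem.objective(params)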
""" |
|
|
|

  def __init__(self, make_loss_and_init_fn):
    """Wraps a model in the Problem interface.

    The make_loss_and_init_fn argument is a callable that returns a tuple of
    two other callables, as follows.

    The first constructs most of the graph and returns the problem loss. It is
    essential that this graph contains all of the model's variables, but none
    of its queues.

    The second constructs the model initialization graph given a list of
    parameters and returns a callable that, when passed an instance of
    tf.Session, initializes the model's parameters.

    An example of a valid argument would look like this:

    ```python
    def make_loss_and_init_fn():
      inputs = queued_reader()

      def make_loss():
        return create_model_with_variables(inputs)

      def make_init_fn(parameters):
        saver = tf.train.Saver(parameters)
        def init_fn(sess):
          saver.restore(sess, ...)
        return init_fn

      return make_loss, make_init_fn
    ```

    Args:
      make_loss_and_init_fn: a callable, as described above.
    """
    make_loss_fn, make_init_fn = make_loss_and_init_fn()

    self.make_loss_fn = make_loss_fn
    self.parameters, self.constants = _get_variables(make_loss_fn)

    if make_init_fn is not None:
      init_fn = make_init_fn(self.parameters + self.constants)
    else:
      init_op = tf.initialize_variables(self.parameters + self.constants)
      init_fn = lambda sess: sess.run(init_op)

    tf.logging.info("ModelAdapter parameters: %s",
                    [op.name for op in self.parameters])
    tf.logging.info("ModelAdapter constants: %s",
                    [op.name for op in self.constants])

    super(ModelAdapter, self).__init__(
        [], random_seed=None, noise_stdev=0.0, init_fn=init_fn)

  def init_tensors(self, seed=None):
    """Returns a list of tensors with the given shape."""
    return self.parameters

  def init_variables(self, seed=None):
    """Returns a list of variables with the given shape."""
    # The adapted model's existing variables are reused directly rather than
    # creating fresh ones.
    return self.parameters

  def objective(self, parameters, data=None, labels=None):
    """Computes the objective given a list of parameters.

    Args:
      parameters: The parameters to optimize (as a list of tensors).
      data: An optional batch of data for calculating objectives.
      labels: An optional batch of corresponding labels.

    Returns:
      A scalar tensor representing the objective value.
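
    Example (a sketch; `adapter` is a ModelAdapter instance and `new_params`
    is assumed to be a list of tensors with the same shapes as
    `self.parameters`):

      loss = adapter.objective(new_params)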
""" |
|
|
|
|
|
    # Map each original variable's name to the tensor that should stand in
    # for it when the loss graph is rebuilt.
    parameter_mapping = {
        old_p.name: p
        for old_p, p in zip(self.parameters, parameters)
    }

    with tf.variable_scope(tf.get_variable_scope(), reuse=True):
      return _make_with_custom_variables(self.make_loss_fn, parameter_mapping)


def _get_variables(func):
  """Calls func, returning any variables created.

  The created variables are modified to not be trainable, and are placed into
  the LOCAL_VARIABLES collection.
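
  For example (a sketch; `my_model_fn` is assumed to create one trainable and
  one non-trainable variable):

    params, consts = _get_variables(my_model_fn)
    # params holds the trainable variable, consts the non-trainable one.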

  Args:
    func: Function to be called.

  Returns:
    A tuple (variables, constants) where the first element is a list of
    trainable variables and the second is the non-trainable variables.
  """
  variables = []
  constants = []

  # Grab the original constructor before it is mocked below, so the variables
  # can still be constructed normally inside custom_init.
  original_init = tf.Variable.__init__

  def custom_init(self, *args, **kwargs):
    trainable = kwargs.get("trainable", True)
    kwargs["trainable"] = False
    # Placing these variables in the LOCAL_VARIABLES collection keeps them out
    # of the default global variables collection, and hence out of default
    # Saver checkpoints.
    kwargs["collections"] = [tf.GraphKeys.LOCAL_VARIABLES]
    original_init(self, *args, **kwargs)
    if trainable:
      variables.append(self)
    else:
      constants.append(self)

  # The name scope groups this throwaway graph together in TensorBoard.
  with tf.name_scope("unused_graph"):
    with mock.patch.object(tf.Variable, "__init__", custom_init):
      func()

  return variables, constants


def _make_with_custom_variables(func, variable_mapping):
  """Calls func and replaces the value of some variables created in it.

  Args:
    func: Function to be called.
    variable_mapping: A mapping of variable name to the replacement tensor or
      tf.Variable.

  Returns:
    The return value of func is returned.
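
  For example (a sketch; `make_loss` and the variable name "w:0" are
  hypothetical), the call

    loss = _make_with_custom_variables(make_loss, {"w:0": replacement_tensor})

  rebuilds the loss graph with `replacement_tensor` standing in for the
  variable named "w:0".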
""" |
|
  original_value = tf.Variable.value

  def custom_value(self):
    if self.name in variable_mapping:
      replacement = variable_mapping[self.name]
      tf.logging.info("Replaced %s with %s", self.name, replacement)

      # If the replacement is itself a variable, substitute its value tensor
      # so that downstream ops always receive a plain tensor.
      if isinstance(replacement, tf.Variable):
        replacement = original_value(replacement)

      return replacement
    else:
      return original_value(self)

  with mock.patch.object(tf.Variable, "value", custom_value):
    with mock.patch.object(tf.Variable, "_AsTensor", custom_value):
      return func()