"""Utilities. |
|
""" |

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import contextlib
import functools
import itertools

import sonnet as snt
import tensorflow as tf

from six.moves import xrange  # pylint: disable=redefined-builtin

from sonnet.python.modules import util as snt_util
from tensorflow.core.framework import node_def_pb2
from tensorflow.python.framework import device as pydev
from tensorflow.python.framework import errors
from tensorflow.python.ops import variable_scope as variable_scope_ops
from tensorflow.python.util import nest

|
def eqzip(*args):
  """Zip, but raise an error if the lengths don't match.

  Args:
    *args: a list of lists or tuples.

  Returns:
    The result of zip.

  Raises:
    ValueError: when the lengths don't match.
  """
  sizes = [len(x) for x in args]
  if not all(sizes[0] == x for x in sizes):
    raise ValueError("Lists are of different sizes.\n%s" % str(sizes))
  return zip(*args)
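

# For example, eqzip([1, 2], ["a", "b"]) yields the pairs (1, "a") and
# (2, "b"), while eqzip([1], [1, 2]) raises ValueError rather than silently
# truncating as plain zip would.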


@contextlib.contextmanager
def assert_no_new_variables():
  """Ensure that no tf.Variables are constructed inside the context.

  Yields:
    None

  Raises:
    ValueError: if a variable is created, or the variable list is mutated.
  """
  old_variables = tf.global_variables()
  num_vars = len(old_variables)
  yield
  if len(tf.global_variables()) != num_vars:
    new_vars = set(tf.global_variables()) - set(old_variables)
    tf.logging.error("NEW VARIABLES CREATED")
    tf.logging.error(10 * "=")
    for v in new_vars:
      tf.logging.error(v)
    raise ValueError("Variables created inside an "
                     "assert_no_new_variables context.")
  if old_variables != tf.global_variables():
    raise ValueError("Variables somehow changed inside an "
                     "assert_no_new_variables context. "
                     "This means something modified tf.global_variables().")


def get_variables_in_modules(module_list):
  """Return the list of variables from all of the given sonnet modules."""
  var_list = []
  for m in module_list:
    var_list.extend(snt.get_variables_in_module(m))
  return var_list


def state_barrier_context(state):
  """Return a context manager that blocks ops until `state` is computed.

  Interior ops will not run until the whole state has been computed. This
  prevents assign race conditions.
  """
  tensors = [x for x in nest.flatten(state) if isinstance(x, tf.Tensor)]
  # TensorArrays are tracked through their flow tensors.
  tarray = [x.flow for x in nest.flatten(state) if hasattr(x, "flow")]
  return tf.control_dependencies(tensors + tarray)


def _identity_fn(tf_entity):
  # Some TF entities (e.g. TensorArray) provide their own identity method.
  if hasattr(tf_entity, "identity"):
    return tf_entity.identity()
  else:
    return tf.identity(tf_entity)


def state_barrier_result(state):
  """Return the same state, with a control dependency that prevents it from
  being partially computed.
  """
  with state_barrier_context(state):
    return nest.map_structure(_identity_fn, state)
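

# Usage sketch: `state` is an arbitrary nest of Tensors / TensorArrays.
#   state = state_barrier_result(state)
# Every tensor in the returned nest now depends on the whole state having
# been computed, so downstream assigns cannot see a partially-computed state.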


def train_iterator(num_iterations):
  """Iterator that yields the index of the current step.

  Runs forever if num_iterations is None, otherwise for num_iterations steps.
  """
  if num_iterations is None:
    return itertools.count()
  else:
    return xrange(num_iterations)
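

# For example, train_iterator(None) counts up until externally stopped, while
# train_iterator(100) yields the step indices 0 through 99.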


def print_op(op, msg):
  """Group `op` with a tf.Print so that `msg` is printed when `op` runs."""
  print_node = tf.Print(tf.constant(0), [tf.constant(0)], msg)
  return tf.group(op, print_node)
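

# Usage sketch (hypothetical train_op):
#   train_op = print_op(train_op, "ran a training step")
# Running the returned op runs the original op and prints the message.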


class MultiQueueRunner(tf.train.QueueRunner):
  """A QueueRunner that drives multiple queues with a single runner."""

  def __init__(self, queues, enqueue_ops):
    close_op = tf.group(*[q.close() for q in queues])
    cancel_op = tf.group(
        *[q.close(cancel_pending_enqueues=True) for q in queues])
    queue_closed_exception_types = (errors.OutOfRangeError,)

    enqueue_op = tf.group(*enqueue_ops, name="multi_enqueue")

    super(MultiQueueRunner, self).__init__(
        queues[0],
        enqueue_ops=[enqueue_op],
        close_op=close_op,
        cancel_op=cancel_op,
        queue_closed_exception_types=queue_closed_exception_types)
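

# Usage sketch (hypothetical queues q1, q2 and enqueue ops enq1, enq2):
#   qr = MultiQueueRunner(queues=[q1, q2], enqueue_ops=[enq1, enq2])
#   tf.train.add_queue_runner(qr)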


def sample_n_per_class(dataset, samples_per_class):
  """Create a callable that returns batches with samples_per_class per label.

  Args:
    dataset: callable, [] -> batch, where batch is a nest with `label` and
      `label_onehot` fields.
    samples_per_class: int, the number of examples of each class per batch.

  Returns:
    function, [] -> batch, where batch is the same type as the return of
    dataset().
  """
  with tf.control_dependencies(None), tf.name_scope(None):
    with tf.name_scope("queue_runner/sample_n_per_class"):
      batch = dataset()
      num_classes = batch.label_onehot.shape.as_list()[1]
      batch_size = num_classes * samples_per_class

      flatten = nest.flatten(batch)
      queues = []
      enqueue_ops = []
      capacity = samples_per_class * 20
      for i in xrange(num_classes):
        # One FIFO queue per class, holding only examples of class i.
        queue = tf.FIFOQueue(
            capacity=capacity,
            shapes=[f.shape.as_list()[1:] for f in flatten],
            dtypes=[f.dtype for f in flatten])
        queues.append(queue)

        # Select the examples in this batch whose label equals i.
        idx = tf.where(tf.equal(batch.label, i))
        to_enqueue = []
        for elem in batch:
          new_e = tf.gather(elem, idx)
          new_e = tf.squeeze(new_e, 1)
          to_enqueue.append(new_e)

        # Enqueue no more examples than the queue has room for.
        remaining = capacity - queue.size()
        to_add = tf.minimum(tf.shape(idx)[0], remaining)

        def _enqueue():
          return queue.enqueue_many([t[:to_add] for t in to_enqueue])

        enqueue_op = tf.cond(tf.equal(to_add, 0), tf.no_op, _enqueue)
        enqueue_ops.append(enqueue_op)

      # Occasionally (~0.05% of enqueues) log queue sizes for monitoring.
      print_lam = lambda: tf.Print(
          tf.constant(0.0), [q.size() for q in queues],
          "MultiQueueRunner queues status. Has capacity %d" % capacity)
      some_percent_of_time = tf.less(tf.random_uniform([]), 0.0005)
      maybe_print = tf.cond(some_percent_of_time, print_lam,
                            lambda: tf.constant(0.0))
      with tf.control_dependencies([maybe_print]):
        enqueue_ops = [tf.group(e) for e in enqueue_ops]
      qr = MultiQueueRunner(queues=queues, enqueue_ops=enqueue_ops)
      tf.train.add_queue_runner(qr)

      def dequeue_batch():
        with tf.name_scope("sample_n_per_batch/dequeue/"):
          entries = []
          for q in queues:
            entries.append(q.dequeue_many(samples_per_class))

          # Concatenate the per-class sub-batches and shuffle within the
          # batch so classes are not grouped together.
          flat_batch = [tf.concat(x, 0) for x in zip(*entries)]
          idx = tf.random_shuffle(tf.range(batch_size))
          flat_batch = [tf.gather(f, idx, axis=0) for f in flat_batch]
          return nest.pack_sequence_as(batch, flat_batch)

      return dequeue_batch
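

# Usage sketch (hypothetical `mnist_dataset`, a callable returning a nest
# with `label` and `label_onehot` fields):
#   get_batch = sample_n_per_class(mnist_dataset, samples_per_class=4)
#   batch = get_batch()  # a class-balanced batch: 4 examples of each label
# The backing queues are filled once tf.train.start_queue_runners is called
# inside a session.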


def structure_map_multi(func, values):
  """Map `func` over tuples of leaves from identically-structured nests."""
  all_values = [nest.flatten(v) for v in values]
  rets = []
  for pair in zip(*all_values):
    rets.append(func(pair))
  return nest.pack_sequence_as(values[0], rets)
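

# For example, to add two identically-structured nests elementwise
# (hypothetical nest_a, nest_b):
#   summed = structure_map_multi(lambda pair: pair[0] + pair[1],
#                                [nest_a, nest_b])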


def structure_map_split(func, value):
  """Map `func`, which returns a sequence, over a nest's leaves, then repack
  the results into one nest per returned component.
  """
  vv = nest.flatten(value)
  rets = []
  for v in vv:
    rets.append(func(v))
  return [nest.pack_sequence_as(value, r) for r in zip(*rets)]
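

# For example, to produce two transformed copies of a nest (hypothetical
# some_nest):
#   halved, doubled = structure_map_split(lambda t: [t * 0.5, t * 2.0],
#                                         some_nest)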


def assign_variables(targets, values):
  """Return a single op assigning each of `values` to the matching target."""
  return tf.group(*[t.assign(v) for t, v in eqzip(targets, values)],
                  name="assign_variables")


def create_variables_in_class_scope(method):
  """Force variables constructed in this method to live in the module's scope.

  Wraps a method on a sonnet module.

  For example the following will create two different sets of variables, and
  reuse them when the same name is passed again.
  ```
  class Mod(snt.AbstractModule):
    @create_variables_in_class_scope
    def dynamic_thing(self, inputs, name):
      return snt.Linear(32, name=name)(inputs)

  mod.dynamic_thing(x, name="module_nameA")
  mod.dynamic_thing(x, name="module_nameB")
  # reuse
  mod.dynamic_thing(y, name="module_nameA")
  ```
  """
  @functools.wraps(method)
  def wrapper(obj, *args, **kwargs):
    def default_context_manager(reuse=None):
      variable_scope = obj.variable_scope
      return tf.variable_scope(variable_scope, reuse=reuse)

    variable_scope_context_manager = getattr(obj, "_enter_variable_scope",
                                             default_context_manager)

    # Capture the module's variable scope, then re-enter it as a pure
    # variable scope (no new name scope is opened), reusing variables.
    with variable_scope_context_manager() as tmp_variable_scope:
      variable_scope = tmp_variable_scope

    with variable_scope_ops._pure_variable_scope(
        variable_scope, reuse=tf.AUTO_REUSE):
      name_scope = variable_scope.original_name_scope
      if name_scope[-1] != "/":
        name_scope += "/"

      with tf.name_scope(name_scope):
        sub_scope = snt_util.to_snake_case(method.__name__)
        with tf.name_scope(sub_scope):
          out_ops = method(obj, *args, **kwargs)
          return out_ops

  return wrapper