# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for creating loop functions."""
from absl import logging
from orbit.utils import tpu_summaries
import tensorflow as tf, tf_keras
def create_loop_fn(step_fn):
"""Creates a loop function driven by a Python `while` loop.
Args:
step_fn: A function taking a nested structure of `tf.data.Iterator` or
`DistributedIterator`. There are no constraints on the return value of the
function (except that it must be compatible with any `reduce_fn` provided
to the returned `loop_fn`).
Returns:
A loop function taking required `iterator` and `num_steps` parameters, as
well as optional `state` and `reduce_fn` parameters for accumulating state
over multiple iterations of the loop. See the `loop_fn` definition below for
additional details.
"""
def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
"""Makes `num_steps` calls to `step_fn(iterator)`.
Additionally, state may be accumulated across iterations of the loop.
Conceptually, state accumulation is handled roughly as follows:
for _ in range(num_steps):
step_outputs = step_fn(iterator)
state = reduce_fn(state, step_outputs)
return state
However, the implementation is slightly more complicated in order to support
looping until the iterator is exhausted (when `num_steps == -1`) and to
properly catch exceptions when running under async remote eager (as is the
case in TPU training setups involving separate coordinator/worker machines).
    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. If `num_steps == -1`, the
        loop will run until the iterator is exhausted.
      state: An optional initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`, or `None` if `state` and
      `reduce_fn` are not provided.
    """
step = 0
try:
      # To make sure the `OutOfRangeError` exception is handled properly under
      # async remote eager, we need to wrap the loop body in `async_scope`.
with tf.experimental.async_scope():
while num_steps == -1 or step < num_steps:
outputs = step_fn(iterator)
if reduce_fn is not None:
state = reduce_fn(state, outputs)
step += 1
return state
except (StopIteration, tf.errors.OutOfRangeError):
logging.info("The dataset iterator is exhausted after %d steps.", step)
tf.experimental.async_clear_error()
return state
return loop_fn
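
# Example usage (an illustrative sketch, not part of the original module;
# `step_fn` below is a stand-in for a real training step):
#
#   dataset = tf.data.Dataset.range(10)
#
#   def step_fn(iterator):
#     return next(iterator)
#
#   loop_fn = create_loop_fn(step_fn)
#   # Runs until the iterator is exhausted (`num_steps == -1`) and sums the
#   # per-step outputs via `reduce_fn`, yielding a tensor equal to 45.
#   total = loop_fn(iter(dataset), num_steps=-1,
#                   state=tf.constant(0, tf.int64),
#                   reduce_fn=lambda state, value: state + value)
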
def create_tf_while_loop_fn(step_fn):
"""Creates a loop function compatible with TF's AutoGraph loop conversion.
Args:
step_fn: A function taking a nested structure of `tf.data.Iterator` or
`DistributedIterator`. Currently, any return values are ignored.
Returns:
A loop function taking required `iterator` and `num_steps` parameters. If
called inside a `tf.function`, the loop will be converted by AutoGraph into
a `tf.while_loop` construct. See the `loop_fn` definition below for
additional details.
"""
def loop_fn(iterator, num_steps):
"""Makes `num_steps` calls to `step_fn(iterator)`.
Args:
iterator: A nested structure of `tf.data.Iterator` or
`DistributedIterator`.
num_steps: The number of steps in the loop. Should be passed as a
`tf.Tensor`. Iterating until iterator exhaustion is not supported.
"""
if not isinstance(num_steps, tf.Tensor):
raise ValueError(
"`num_steps` should be a `tf.Tensor`. Passing a Python value can "
"cause unnecessary retracing when wrapped by `tf.function`.")
for _ in tf.range(num_steps):
      # Clear out the outer name scope so that the ops created inside the
      # `tf.while_loop` don't get "while/" as a name prefix.
with tf.name_scope(""):
step_fn(iterator)
return loop_fn
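
# Example usage (an illustrative sketch, not part of the original module;
# `counter` and `step_fn` below are stand-ins for real training logic):
#
#   counter = tf.Variable(0, dtype=tf.int64)
#
#   def step_fn(iterator):
#     counter.assign_add(next(iterator))
#
#   # Wrapping in `tf.function` lets AutoGraph convert the `tf.range` loop
#   # into a `tf.while_loop`; `num_steps` is passed as a `tf.Tensor` so that
#   # different step counts do not trigger retracing.
#   loop_fn = tf.function(create_tf_while_loop_fn(step_fn))
#   loop_fn(iter(tf.data.Dataset.range(100)), tf.constant(10))
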
def create_tf_while_loop_fn_with_state(step_fn):
"""Creates a TF while loop function with state.
This function is similar to `create_tf_while_loop_fn`, but allowing a `state`
to be accumulated over multiple iterations of the loop. Note that the
structure of the `state` cannot be changed across iterations.
Args:
step_fn: A function taking a nested structure of `tf.data.Iterator` or
`DistributedIterator`. Currently, any return values are ignored.
Returns:
A loop function taking required `iterator`, `num_steps`, `state` and
`reduce_fn` parameters. If called inside a `tf.function`, the loop will be
converted by AutoGraph into a `tf.while_loop` construct. See the `loop_fn`
definition below for additional details.
"""
def loop_fn_with_state(iterator, num_steps, state, reduce_fn):
"""Makes `num_steps` calls to `step_fn(iterator)`.
Args:
iterator: A nested structure of `tf.data.Iterator` or
`DistributedIterator`.
num_steps: The number of steps in the loop. Should be passed as a
`tf.Tensor`. Iterating until iterator exhaustion is not supported.
state: An initial state before running the loop.
reduce_fn: A callable taking two inputs, `state` and `value`, where
`state` is the previous output from `reduce_fn`, and `value` is the
output from `step_fn`.
Returns:
The final state returned by `reduce_fn`.
"""
if not isinstance(num_steps, tf.Tensor):
raise ValueError(
"`num_steps` should be a `tf.Tensor`. Passing a Python value can "
"cause unnecessary retracing when wrapped by `tf.function`.")

    def _get_relaxed_tensor_shape(t):
      """Returns a `TensorShape` with all `None` dimensions."""
      if not tf.is_tensor(t):
        return None
      shape = t.shape
      if shape.rank is not None and shape.rank > 0:
        return tf.TensorShape([None] * shape.rank)
      return shape

    def _get_relaxed_shape_structure(s):
      """Returns the relaxed shape of the input nested structure `s`."""
      return tf.nest.pack_sequence_as(
          s, [_get_relaxed_tensor_shape(t) for t in tf.nest.flatten(s)])

for _ in tf.range(num_steps):
      # Clear out the outer name scope so that the ops created inside the
      # `tf.while_loop` don't get "while/" as a name prefix.
with tf.name_scope(""):
        # Relax the shapes within the loop so that the shape of `state` can
        # change across iterations. This is useful for aggregating outputs
        # from each step, e.g. by concatenating them onto `state`.
tf.autograph.experimental.set_loop_options(
shape_invariants=[(state, _get_relaxed_shape_structure(state))])
outputs = step_fn(iterator)
state = reduce_fn(state, outputs)
return state
return loop_fn_with_state
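
# Example usage (an illustrative sketch, not part of the original module):
#
#   def step_fn(iterator):
#     return tf.reshape(next(iterator), [1])
#
#   loop_fn = tf.function(create_tf_while_loop_fn_with_state(step_fn))
#   # The relaxed shape invariants let `state` grow across iterations, here
#   # from shape [0] to shape [5] by concatenation.
#   result = loop_fn(iter(tf.data.Dataset.range(5)), tf.constant(5),
#                    tf.zeros([0], dtype=tf.int64),
#                    lambda state, value: tf.concat([state, value], axis=0))
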
class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
"""Implements a two-program approach for optimizing summaries on TPU.
This version works with the result of `create_tf_while_loop_fn`.
"""
def __call__(self, iterator, num_steps):
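    # Run the first step with summary writing enabled, then run any remaining
    # steps with summaries disabled, keeping the steady-state loop cheap.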
if tf.summary.should_record_summaries():
output = self.with_summaries(iterator, tf.constant(1))
num_steps -= 1
if num_steps >= 1:
output = self.without_summaries(iterator, num_steps)
return output
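
# Example usage (an illustrative sketch, not part of the original module;
# `step_fn` and `iterator` are stand-ins):
#
#   loop_fn = LoopFnWithSummaries(create_tf_while_loop_fn(step_fn))
#   writer = tf.summary.create_file_writer("/tmp/summaries")
#   with writer.as_default(), tf.summary.record_if(True):
#     loop_fn(iterator, tf.constant(100))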