# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for creating loop functions."""
from absl import logging | |
from orbit.utils import tpu_summaries | |
import tensorflow as tf, tf_keras | |
def create_loop_fn(step_fn):
  """Creates a loop function driven by a Python `while` loop.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. There are no constraints on its return value
      (except that it must be compatible with any `reduce_fn` provided to the
      returned `loop_fn`).

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters, as
    well as optional `state` and `reduce_fn` parameters for accumulating state
    over multiple iterations of the loop. See the `loop_fn` definition below
    for additional details.
  """

  def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
    """Calls `step_fn(iterator)` up to `num_steps` times.

    State may additionally be accumulated across iterations. Conceptually:

        for _ in range(num_steps):
          step_outputs = step_fn(iterator)
          state = reduce_fn(state, step_outputs)
        return state

    The real implementation is slightly more involved so that it can (a) loop
    until the iterator is exhausted when `num_steps == -1`, and (b) correctly
    catch exceptions under async remote eager execution (as used in TPU setups
    with separate coordinator/worker machines).

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps to run. If `num_steps == -1`, iterates
        until the iterator is exhausted.
      state: An optional initial state before running the loop.
      reduce_fn: A callable taking `(state, value)`, where `state` is the
        previous output of `reduce_fn` and `value` is the output of `step_fn`.

    Returns:
      The final state returned by `reduce_fn`, or `None` if `state` and
      `reduce_fn` are not provided.
    """
    steps_done = 0
    try:
      # Wrapping the loop body in `async_scope` ensures that an
      # OutOfRangeError raised by async remote eager execution surfaces here,
      # where it can be handled alongside `StopIteration`.
      with tf.experimental.async_scope():
        while num_steps == -1 or steps_done < num_steps:
          step_outputs = step_fn(iterator)
          if reduce_fn is not None:
            state = reduce_fn(state, step_outputs)
          steps_done += 1
        return state
    except (StopIteration, tf.errors.OutOfRangeError):
      logging.info("The dataset iterator is exhausted after %d steps.",
                   steps_done)
      tf.experimental.async_clear_error()
      return state

  return loop_fn
def create_tf_while_loop_fn(step_fn):
  """Creates a loop function compatible with TF's AutoGraph loop conversion.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters. If
    called inside a `tf.function`, the loop will be converted by AutoGraph into
    a `tf.while_loop` construct. See the `loop_fn` definition below for
    additional details.
  """

  def loop_fn(iterator, num_steps):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor`.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")
    # Iterating over `tf.range` (rather than a Python `range`) is what allows
    # AutoGraph to convert this loop into a `tf.while_loop` when traced.
    for _ in tf.range(num_steps):
      # Clear out the outer name scope so the ops created inside
      # `tf.while_loop` don't get "while/" as name prefix.
      with tf.name_scope(""):
        step_fn(iterator)

  return loop_fn
def create_tf_while_loop_fn_with_state(step_fn):
  """Creates a TF while loop function with state.

  This function is similar to `create_tf_while_loop_fn`, but allowing a `state`
  to be accumulated over multiple iterations of the loop. Note that the
  structure of the `state` cannot be changed across iterations.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator`, `num_steps`, `state` and
    `reduce_fn` parameters. If called inside a `tf.function`, the loop will be
    converted by AutoGraph into a `tf.while_loop` construct. See the `loop_fn`
    definition below for additional details.
  """

  def loop_fn_with_state(iterator, num_steps, state, reduce_fn):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
      state: An initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor`.
    """
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    def _get_relaxed_tensor_shape(t):
      """Returns `t`'s shape with all dimensions relaxed to `None`."""
      if not tf.is_tensor(t):
        return None
      shape = t.shape
      if shape.rank is not None and shape.rank > 0:
        return tf.TensorShape([None] * shape.rank)
      # Scalars (rank 0) and unknown-rank shapes are returned unchanged.
      return shape

    def _get_relaxed_shape_structure(s):
      """Returns the relaxed shapes of the nested structure `s`."""
      # Fix: pack using the structure of the argument `s` itself rather than
      # the enclosing `state`; the previous code silently assumed `s is state`
      # and would mis-pack (or raise) for any other structure.
      return tf.nest.pack_sequence_as(
          s, [_get_relaxed_tensor_shape(t) for t in tf.nest.flatten(s)])

    for _ in tf.range(num_steps):
      # Clear out the outer name scope so the ops created inside
      # `tf.while_loop` don't get "while/" as name prefix.
      with tf.name_scope(""):
        # Relax the shapes within the loop, so the shape of `state` can change
        # across iterations. This is useful to aggregate outputs from each
        # step and concat to `state`.
        tf.autograph.experimental.set_loop_options(
            shape_invariants=[(state, _get_relaxed_shape_structure(state))])
        outputs = step_fn(iterator)
        state = reduce_fn(state, outputs)
    return state

  return loop_fn_with_state
class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
  """Implements a two-program approach for optimizing summaries on TPU.

  This version works with the result of `create_tf_while_loop_fn`.
  """

  def __call__(self, iterator, num_steps):
    """Runs `num_steps` steps, writing summaries on at most the first step.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`, forwarded to the wrapped loop functions.
      num_steps: The number of steps to run, as a `tf.Tensor` (it is mutated
        with `-=` below, so a Python int would also work numerically but may
        cause retracing).

    Returns:
      The output of the last loop function invoked.
    """
    # `with_summaries` / `without_summaries` are provided by the base class
    # (tpu_summaries.OptionalSummariesFunction) — presumably the same step
    # function traced with and without summary ops; verify against that class.
    if tf.summary.should_record_summaries():
      # Run exactly one step with summaries enabled, then fall through to run
      # the remainder with the summary-free program.
      output = self.with_summaries(iterator, tf.constant(1))
      num_steps -= 1
    if num_steps >= 1:
      output = self.without_summaries(iterator, num_steps)
    # NOTE(review): if summaries are not being recorded and `num_steps < 1`,
    # `output` is never bound and this raises NameError — callers appear to
    # always pass `num_steps >= 1`; confirm.
    return output