# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Utilities for creating loop functions."""
from absl import logging
from orbit.utils import tpu_summaries
import tensorflow as tf, tf_keras
def create_loop_fn(step_fn):
"""Creates a loop function driven by a Python `while` loop.
Args:
step_fn: A function taking a nested structure of `tf.data.Iterator` or
`DistributedIterator`. There are no constraints on the return value of the
function (except that it must be compatible with any `reduce_fn` provided
to the returned `loop_fn`).
Returns:
A loop function taking required `iterator` and `num_steps` parameters, as
well as optional `state` and `reduce_fn` parameters for accumulating state
over multiple iterations of the loop. See the `loop_fn` definition below for
additional details.
"""
def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
"""Makes `num_steps` calls to `step_fn(iterator)`.
Additionally, state may be accumulated across iterations of the loop.
Conceptually, state accumulation is handled roughly as follows:
for _ in range(num_steps):
step_outputs = step_fn(iterator)
state = reduce_fn(state, step_outputs)
return state
However, the implementation is slightly more complicated in order to support
looping until the iterator is exhausted (when `num_steps == -1`) and to
properly catch exceptions when running under async remote eager (as is the
case in TPU training setups involving separate coordinator/worker machines).
Args:
iterator: A nested structure of `tf.data.Iterator` or
`DistributedIterator`.
num_steps: The number of steps in the loop. If `num_steps == -1`, will
iterate until exausting the iterator.
state: An optional initial state before running the loop.
reduce_fn: A callable taking two inputs, `state` and `value`, where
`state` is the previous output from `reduce_fn`, and `value` is the
output from `step_fn`.
Returns:
The final state returned by `reduce_fn`, or `None` if `state` and
`reduce_fn` are not provided.
"""
step = 0
try:
# To make sure the OutOfRangeError exception can be handled well under
# async remote eager, we need to wrap the loop body in `async_scope`.
with tf.experimental.async_scope():
while num_steps == -1 or step < num_steps:
outputs = step_fn(iterator)
if reduce_fn is not None:
state = reduce_fn(state, outputs)
step += 1
return state
except (StopIteration, tf.errors.OutOfRangeError):
logging.info("The dataset iterator is exhausted after %d steps.", step)
tf.experimental.async_clear_error()
return state

  return loop_fn
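

# Illustrative usage sketch, not part of the original module: drives the loop
# function from `create_loop_fn` over a small `tf.data` dataset until the
# iterator is exhausted, summing step outputs via `reduce_fn`. The names below
# (`_example_sum_with_loop_fn`, `step`, `reduce_sum`) are hypothetical.
def _example_sum_with_loop_fn():
  def step(iterator):
    return next(iterator)

  def reduce_sum(state, value):
    return value if state is None else state + value

  loop_fn = create_loop_fn(step)
  iterator = iter(tf.data.Dataset.range(10))
  # `num_steps=-1` loops until `StopIteration`/`OutOfRangeError` is raised,
  # which `loop_fn` catches before returning the accumulated state (here 45).
  return loop_fn(iterator, num_steps=-1, state=None, reduce_fn=reduce_sum)

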
def create_tf_while_loop_fn(step_fn):
"""Creates a loop function compatible with TF's AutoGraph loop conversion.
Args:
step_fn: A function taking a nested structure of `tf.data.Iterator` or
`DistributedIterator`. Currently, any return values are ignored.
Returns:
A loop function taking required `iterator` and `num_steps` parameters. If
called inside a `tf.function`, the loop will be converted by AutoGraph into
a `tf.while_loop` construct. See the `loop_fn` definition below for
additional details.
"""
def loop_fn(iterator, num_steps):
"""Makes `num_steps` calls to `step_fn(iterator)`.
Args:
iterator: A nested structure of `tf.data.Iterator` or
`DistributedIterator`.
num_steps: The number of steps in the loop. Should be passed as a
`tf.Tensor`. Iterating until iterator exhaustion is not supported.
"""
if not isinstance(num_steps, tf.Tensor):
raise ValueError(
"`num_steps` should be a `tf.Tensor`. Passing a Python value can "
"cause unnecessary retracing when wrapped by `tf.function`.")
for _ in tf.range(num_steps):
# Clear out the outer name scope so the ops created inside `tf.while_loop`
# don't get "while/" as name prefix.
with tf.name_scope(""):
step_fn(iterator)

  return loop_fn
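

# Illustrative usage sketch, not part of the original module: wraps the loop
# function in a `tf.function` so AutoGraph converts the `tf.range` loop into a
# `tf.while_loop`, passing `num_steps` as a `tf.Tensor` to avoid retracing.
# The function and variable names here are hypothetical.
def _example_tf_while_loop():
  total = tf.Variable(0, dtype=tf.int64)

  def step(iterator):
    total.assign_add(next(iterator))

  loop_fn = tf.function(create_tf_while_loop_fn(step))
  iterator = iter(tf.data.Dataset.range(10))
  loop_fn(iterator, tf.constant(5))  # Consumes elements 0..4.
  return total  # Holds 0 + 1 + 2 + 3 + 4 = 10.

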
def create_tf_while_loop_fn_with_state(step_fn):
"""Creates a TF while loop function with state.
This function is similar to `create_tf_while_loop_fn`, but allowing a `state`
to be accumulated over multiple iterations of the loop. Note that the
structure of the `state` cannot be changed across iterations.
Args:
step_fn: A function taking a nested structure of `tf.data.Iterator` or
`DistributedIterator`. Currently, any return values are ignored.
Returns:
A loop function taking required `iterator`, `num_steps`, `state` and
`reduce_fn` parameters. If called inside a `tf.function`, the loop will be
converted by AutoGraph into a `tf.while_loop` construct. See the `loop_fn`
definition below for additional details.
"""
def loop_fn_with_state(iterator, num_steps, state, reduce_fn):
"""Makes `num_steps` calls to `step_fn(iterator)`.
Args:
iterator: A nested structure of `tf.data.Iterator` or
`DistributedIterator`.
num_steps: The number of steps in the loop. Should be passed as a
`tf.Tensor`. Iterating until iterator exhaustion is not supported.
state: An initial state before running the loop.
reduce_fn: A callable taking two inputs, `state` and `value`, where
`state` is the previous output from `reduce_fn`, and `value` is the
output from `step_fn`.
Returns:
The final state returned by `reduce_fn`.
"""
if not isinstance(num_steps, tf.Tensor):
raise ValueError(
"`num_steps` should be a `tf.Tensor`. Passing a Python value can "
"cause unnecessary retracing when wrapped by `tf.function`.")
    def _get_relaxed_tensor_shape(t):
      """Returns a `TensorShape` with all `None` dimensions."""
      if not tf.is_tensor(t):
        return None
      shape = t.shape
      if shape.rank is not None and shape.rank > 0:
        return tf.TensorShape([None] * shape.rank)
      return shape

    def _get_relaxed_shape_structure(s):
      """Returns the relaxed shape of the input nested structure `s`."""
      return tf.nest.pack_sequence_as(
          state, [_get_relaxed_tensor_shape(t) for t in tf.nest.flatten(s)])

    for _ in tf.range(num_steps):
# Clear out the outer name scope so the ops created inside `tf.while_loop`
# don't get "while/" as name prefix.
with tf.name_scope(""):
        # Relax the shapes within the loop so that the shape of `state` can
        # change across iterations. This is useful for aggregating outputs
        # from each step and concatenating them onto `state`.
tf.autograph.experimental.set_loop_options(
shape_invariants=[(state, _get_relaxed_shape_structure(state))])
outputs = step_fn(iterator)
state = reduce_fn(state, outputs)
return state

  return loop_fn_with_state
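

# Illustrative usage sketch, not part of the original module: accumulates
# per-step outputs by concatenation, which is exactly the case the relaxed
# shape invariants above exist for, since `state` grows by one element on
# every iteration. All names here are hypothetical.
def _example_tf_while_loop_with_state():
  def step(iterator):
    return next(iterator)

  def concat(state, value):
    return tf.concat([state, tf.reshape(value, [1])], axis=0)

  loop_fn = tf.function(create_tf_while_loop_fn_with_state(step))
  iterator = iter(tf.data.Dataset.range(10))
  initial = tf.zeros([0], dtype=tf.int64)
  # Returns a tensor holding the first five elements: [0, 1, 2, 3, 4].
  return loop_fn(iterator, tf.constant(5), initial, concat)

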
class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
  """Implements a two-program approach for optimizing summaries on TPU.

  This version works with the result of `create_tf_while_loop_fn`.
  """

  def __call__(self, iterator, num_steps):
    if tf.summary.should_record_summaries():
      output = self.with_summaries(iterator, tf.constant(1))
      num_steps -= 1
      if num_steps >= 1:
        output = self.without_summaries(iterator, num_steps)
    else:
      # Without this branch, `output` would be unbound when summary
      # recording is disabled, raising an `UnboundLocalError` below.
      output = self.without_summaries(iterator, num_steps)
    return output
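

# Illustrative usage sketch, not part of the original module, assuming
# `tpu_summaries.OptionalSummariesFunction` compiles the wrapped loop into the
# `with_summaries`/`without_summaries` variants used above. Summaries are
# recorded on the first step only; the remaining steps run the summary-free
# program. The writer path and all names here are hypothetical.
def _example_loop_with_summaries():
  def step(iterator):
    value = next(iterator)
    tf.summary.scalar("value", tf.cast(value, tf.float32), step=value)

  loop_fn = LoopFnWithSummaries(create_tf_while_loop_fn(step))
  iterator = iter(tf.data.Dataset.range(10))
  writer = tf.summary.create_file_writer("/tmp/summaries")
  with writer.as_default(), tf.summary.record_if(True):
    loop_fn(iterator, tf.constant(5))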