File size: 8,026 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93528c6
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Utilities for creating loop functions."""

from absl import logging
from orbit.utils import tpu_summaries

import tensorflow as tf, tf_keras


def create_loop_fn(step_fn):
  """Creates a loop function driven by a Python `while` loop.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. There are no constraints on the return value of
      the function (except that it must be compatible with any `reduce_fn`
      provided to the returned `loop_fn`).

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters, as
    well as optional `state` and `reduce_fn` parameters for accumulating state
    over multiple iterations of the loop. See the `loop_fn` definition below
    for additional details.
  """

  def loop_fn(iterator, num_steps, state=None, reduce_fn=None):
    """Makes `num_steps` calls to `step_fn(iterator)`.

    State may also be accumulated across iterations. Conceptually this is
    equivalent to:

        for _ in range(num_steps):
          step_outputs  = step_fn(iterator)
          state = reduce_fn(state, step_outputs)
        return state

    The real implementation additionally supports looping until the iterator
    is exhausted (`num_steps == -1`) and correctly surfaces out-of-range
    errors under async remote eager (e.g. TPU setups with separate
    coordinator/worker machines).

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. If `num_steps == -1`, will
        iterate until the iterator is exhausted.
      state: An optional initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`, or `None` if `state` and
      `reduce_fn` are not provided.
    """
    steps_done = 0
    try:
      # Wrap the loop body in `async_scope` so that an OutOfRangeError raised
      # remotely (async remote eager) is delivered here, where the `except`
      # clause below can handle it.
      with tf.experimental.async_scope():
        while num_steps == -1 or steps_done < num_steps:
          step_outputs = step_fn(iterator)
          if reduce_fn is not None:
            state = reduce_fn(state, step_outputs)
          steps_done += 1
        return state
    except (StopIteration, tf.errors.OutOfRangeError):
      logging.info("The dataset iterator is exhausted after %d steps.",
                   steps_done)
      tf.experimental.async_clear_error()
      return state

  return loop_fn


def create_tf_while_loop_fn(step_fn):
  """Creates a loop function compatible with TF's AutoGraph loop conversion.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator` and `num_steps` parameters. If
    called inside a `tf.function`, the loop will be converted by AutoGraph
    into a `tf.while_loop` construct. See the `loop_fn` definition below for
    additional details.
  """

  def loop_fn(iterator, num_steps):
    """Calls `step_fn(iterator)` exactly `num_steps` times.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor`.
    """
    # Requiring a tensor avoids a retrace of the enclosing `tf.function` for
    # each distinct Python value of `num_steps`.
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    for _step in tf.range(num_steps):
      # Reset the name scope so ops created inside the converted
      # `tf.while_loop` don't pick up a "while/" name prefix.
      with tf.name_scope(""):
        step_fn(iterator)

  return loop_fn


def create_tf_while_loop_fn_with_state(step_fn):
  """Creates a TF while loop function with state.

  This function is similar to `create_tf_while_loop_fn`, but allowing a `state`
  to be accumulated over multiple iterations of the loop. Note that the
  structure of the `state` cannot be changed across iterations.

  Args:
    step_fn: A function taking a nested structure of `tf.data.Iterator` or
      `DistributedIterator`. Currently, any return values are ignored.

  Returns:
    A loop function taking required `iterator`, `num_steps`, `state` and
    `reduce_fn` parameters. If called inside a `tf.function`, the loop will be
    converted by AutoGraph into a `tf.while_loop` construct. See the `loop_fn`
    definition below for additional details.
  """

  def loop_fn_with_state(iterator, num_steps, state, reduce_fn):
    """Makes `num_steps` calls to `step_fn(iterator)`, reducing into `state`.

    Args:
      iterator: A nested structure of `tf.data.Iterator` or
        `DistributedIterator`.
      num_steps: The number of steps in the loop. Should be passed as a
        `tf.Tensor`. Iterating until iterator exhaustion is not supported.
      state: An initial state before running the loop.
      reduce_fn: A callable taking two inputs, `state` and `value`, where
        `state` is the previous output from `reduce_fn`, and `value` is the
        output from `step_fn`.

    Returns:
      The final state returned by `reduce_fn`.

    Raises:
      ValueError: If `num_steps` is not a `tf.Tensor`.
    """
    # Requiring a tensor avoids a retrace of the enclosing `tf.function` for
    # each distinct Python value of `num_steps`.
    if not isinstance(num_steps, tf.Tensor):
      raise ValueError(
          "`num_steps` should be a `tf.Tensor`. Passing a Python value can "
          "cause unnecessary retracing when wrapped by `tf.function`.")

    def _get_relaxed_tensor_shape(t):
      """Returns a `TensorShape` with all `None` dimensions."""
      if not tf.is_tensor(t):
        return None

      shape = t.shape
      if shape.rank is not None and shape.rank > 0:
        return tf.TensorShape([None] * shape.rank)
      return shape

    def _get_relaxed_shape_structure(s):
      """Returns the relaxed shape of the input nested structure `s`."""
      # Fix: pack against `s` (the argument), not the closed-over `state`.
      # Previously this packed against `state`, silently ignoring the
      # parameter's own structure; behavior is unchanged for the existing
      # call below, which passes `state`.
      return tf.nest.pack_sequence_as(
          s, [_get_relaxed_tensor_shape(t) for t in tf.nest.flatten(s)])

    for _ in tf.range(num_steps):
      # Clear out the outer name scope so the ops created inside the converted
      # `tf.while_loop` don't get "while/" as name prefix.
      with tf.name_scope(""):
        # Relax the shapes within the loop, so the shape of `state` can change
        # across iterations. This is useful to aggregate outputs from each
        # step and concat to `state`.
        tf.autograph.experimental.set_loop_options(
            shape_invariants=[(state, _get_relaxed_shape_structure(state))])
        outputs = step_fn(iterator)
        state = reduce_fn(state, outputs)
    return state

  return loop_fn_with_state


class LoopFnWithSummaries(tpu_summaries.OptionalSummariesFunction):
  """Implements a two-program approach for optimizing summaries on TPU.

  This version works with the result of `create_tf_while_loop_fn`.
  """

  def __call__(self, iterator, num_steps):
    """Runs the loop, writing summaries only for the first step.

    Args:
      iterator: The iterator to pass through to the wrapped loop functions.
      num_steps: The total number of steps to run.

    Returns:
      The output of the last loop function invoked.
    """
    # `with_summaries` / `without_summaries` are presumably supplied by the
    # `tpu_summaries.OptionalSummariesFunction` base class — not visible in
    # this file; verify against its definition.
    if tf.summary.should_record_summaries():
      # Run exactly one step with summaries enabled, then fall through to run
      # the remaining `num_steps - 1` steps without them.
      output = self.with_summaries(iterator, tf.constant(1))
      num_steps -= 1
    # NOTE(review): assumes `num_steps >= 1` on entry — if called with zero
    # steps while summaries are disabled, `output` is never bound and this
    # raises; confirm callers guarantee at least one step.
    if num_steps >= 1:
      output = self.without_summaries(iterator, num_steps)
    return output