# Copyright 2023 The Orbit Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Contains utilities for TPU summary optimization.""" import contextlib import functools import tensorflow as tf, tf_keras @contextlib.contextmanager def _soft_device_placement(): """Context manager for soft device placement, allowing summaries on CPU.""" original_setting = tf.config.get_soft_device_placement() try: tf.config.set_soft_device_placement(True) yield finally: tf.config.set_soft_device_placement(original_setting) class OptionalSummariesFunction: """Wrapper that provides versions of a function with and without summaries. This is a utility class for implementing optimized summary recording via a two-function approach, specifically important for TPUs. Two `tf.function` versions of a given `function` are created: one with soft device placement enabled (for use on steps that require summary writing), and one with summary writing and soft device placement entirely disabled (for use on all other steps). This removes any performance impact of summaries on steps where they aren't recorded (b/148418718). This class can be used as a base class to implement summary optimizations for a function with a specific signature. For example, to implement efficient TPU summaries for a standard `train()` method (as in `orbit.AbstractTrainer`): class TrainFunctionWithSummaries(orbit.utils.OptionalSummariesFunction): '''Implements a two-program approach for summaries on TPU.''' def __call__(self, num_steps): if tf.summary.should_record_summaries(): output = self.with_summaries(tf.constant(1)) num_steps -= 1 if num_steps >= 1: output = self.without_summaries(num_steps) return output This can be used directly or to implement a decorator: def train_function_with_summaries(function=None, **kwargs): if function is not None: return TrainFunctionWithSummaries(function, **kwargs) return functools.partial(TrainFunctionWithSummaries, **kwargs) The decorator can be applied directly to `train()` methods: @train_function_with_summaries def train(self, num_steps): ... A similar approach approach can be implemented for functions with different signatures. Note: The above approach assumes that the frequency of summary writing is based on a step interval that is divisible by the number of steps executed in each call to the `train()` function. This is enforced by the `orbit.Controller`. This wrapper properly handles instance methods (see `__get__`). Attributes: with_summaries: A wrapped version of the underlying function with summaries enabled (using whatever the active predicate is for `tf.summary.record_if`), and placed inside a "soft device placement" context to enable summary recording on TPU. without_summaries: A wrapped version of the underlying function with all summary recording disabled. """ def __init__(self, function, **tf_function_kwargs): """Constructs an instance wrapping the given `function`. The given `function` is wrapped twice: Once in a "soft device placement" context (allowing summaries to also run on TPU), and once with summary recording entirely disabled. Both of these versions are compiled via `tf.function` (optionally using any supplied `tf.function` settings), and made available as attributes. Args: function: The underlying function to wrap. **tf_function_kwargs: Additional arguments to pass to `tf.function`. """ @tf.function(**tf_function_kwargs) @functools.wraps(function) def with_summaries(*args, **kwargs): with _soft_device_placement(): return function(*args, **kwargs) @tf.function(**tf_function_kwargs) @functools.wraps(function) def without_summaries(*args, **kwargs): with tf.summary.record_if(False): return function(*args, **kwargs) self.with_summaries = with_summaries self.without_summaries = without_summaries def __get__(self, instance, owner): """Allows this class to be used to wrap methods as well as free functions. For `tf.function` to work properly in all cases (e.g., when an input_signature is specified), any `tf.function`-converted methods must be properly bound to an instance if they are called as an instance method. This is done by implementing this `__get__` method of the descriptor protocol, and forwarding to the `__get__` method on the underlying `tf.function`s. Args: instance: The instance to bind to. owner: The class type of the instance. Returns: A new bound instance of `TpuDiscretionarySummariesFunctions`. """ new = object.__new__(self.__class__) # pytype: disable=attribute-error # See b/162476201. new.with_summaries = self.with_summaries.__get__(instance, owner) new.without_summaries = self.without_summaries.__get__(instance, owner) # pytype: enable=attribute-error return new