Spaces:
Runtime error
Runtime error
File size: 4,287 Bytes
5672777 93528c6 5672777 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Provides a utility class for managing summary writing."""
import os
from orbit.utils.summary_manager_interface import SummaryManagerInterface
import tensorflow as tf, tf_keras
class SummaryManager(SummaryManagerInterface):
  """Writes (possibly nested) dictionaries of values as `tf.summary` summaries.

  A summary writer is created lazily for each subdirectory of `summary_dir`
  that is requested, and cached for reuse. If `summary_dir` is `None`, summary
  writing is disabled and the write/flush methods do nothing.
  """

  def __init__(self, summary_dir, summary_fn, global_step=None):
    """Initializes the `SummaryManager` instance.

    Args:
      summary_dir: Root directory for summary event files. If `None`, all
        summary writing operations provided by this class are no-ops.
      summary_fn: A callable invoked as `summary_fn(name, value, step=step)`
        for every leaf value; it is expected to call `tf.summary` functions
        to record that value.
      global_step: A `tf.Variable` containing the global step value. If
        `None`, the result of `tf.summary.experimental.get_step()` at
        construction time is used instead.
    """
    self._summary_dir = summary_dir
    self._enabled = summary_dir is not None
    self._summary_fn = summary_fn
    # Writers keyed by their path relative to `summary_dir`.
    self._summary_writers = {}
    # Resolved once, at construction time.
    self._global_step = (
        tf.summary.experimental.get_step() if global_step is None
        else global_step)

  def summary_writer(self, relative_path=""):
    """Returns the summary writer for a subdirectory, creating it if needed.

    Writers are cached, so repeated calls with the same `relative_path`
    return the same writer object.

    Args:
      relative_path: Path, relative to the summary directory, that summaries
        should be written to. Defaults to the empty string, i.e. the root
        summary directory.

    Returns:
      A file writer for `os.path.join(summary_dir, relative_path)`, or a
      no-op writer when summary writing is disabled.
    """
    cached = self._summary_writers.get(relative_path)
    if cached is not None:
      return cached
    if not self._enabled:
      writer = tf.summary.create_noop_writer()
    else:
      writer = tf.summary.create_file_writer(
          os.path.join(self._summary_dir, relative_path))
    self._summary_writers[relative_path] = writer
    return writer

  def flush(self):
    """Flushes every summary writer created so far, in sorted path order.

    Does nothing when summary writing is disabled.
    """
    if not self._enabled:
      return
    for path in sorted(self._summary_writers):
      tf.summary.flush(self._summary_writers[path])

  def write_summaries(self, summary_dict):
    """Writes summaries for a (possibly nested) dictionary of values.

    Nested dictionaries map onto nested subdirectories of the summary
    directory, which the TensorBoard UI shows as separately colored curves.
    For example, results from evaluating on two datasets can be passed as:

      {
          "dataset1": {
              "loss": loss1,
              "accuracy": accuracy1
          },
          "dataset2": {
              "loss": loss2,
              "accuracy": accuracy2
          },
      }

    This writes both "loss" and "accuracy" summaries into each of two
    subdirectories, "dataset1" and "dataset2", of the summary root directory.

    Args:
      summary_dict: A dictionary of values. Any value that is itself a `dict`
        becomes a subdirectory named after its key; this applies recursively.
        Every other (leaf) value is passed to `summary_fn` using the summary
        writer for its parent's relative path.
    """
    if self._enabled:
      self._write_summaries(summary_dict)

  def _write_summaries(self, summary_dict, relative_path=""):
    """Recursive helper for `write_summaries`.

    Args:
      summary_dict: Dictionary at the current nesting level.
      relative_path: Subdirectory, relative to the summary directory, that
        leaf values at this nesting level are written to.
    """
    for key, item in summary_dict.items():
      if not isinstance(item, dict):
        with self.summary_writer(relative_path).as_default():
          self._summary_fn(key, item, step=self._global_step)
        continue
      # Descend one level, extending the subdirectory path with `key`.
      self._write_summaries(
          item, relative_path=os.path.join(relative_path, key))
|