File size: 4,523 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Tests for orbit.utils.tpu_summaries."""

import functools
import os

from orbit.utils import common
from orbit.utils import tpu_summaries

import tensorflow as tf, tf_keras


class TrainFunctionWithSummaries(tpu_summaries.OptionalSummariesFunction):
  """Implements a two-program approach for summaries on TPU."""

  def __call__(self, num_steps):
    if tf.summary.should_record_summaries():
      output = self.with_summaries(tf.constant(1))
      num_steps -= 1
    if num_steps >= 1:
      output = self.without_summaries(num_steps)
    return output


def train_function_with_summaries(function=None, **kwargs):
  if function is not None:
    return TrainFunctionWithSummaries(function, **kwargs)
  return functools.partial(TrainFunctionWithSummaries, **kwargs)


class DummyTrainer(tf.Module):

  def __init__(self):
    self.step_counter = common.create_global_step()

  @train_function_with_summaries
  def train_with_tpu_summary_optimization(self, num_steps):
    for _ in tf.range(num_steps):
      tf.summary.scalar("step", self.step_counter, step=self.step_counter)
      self.step_counter.assign_add(1)
    return self.step_counter

  @train_function_with_summaries(
      input_signature=[tf.TensorSpec((), dtype=tf.int32)])
  def train_with_tpu_summary_optimization_and_input_signature(self, num_steps):
    for _ in tf.range(num_steps):
      tf.summary.scalar("step", self.step_counter, step=self.step_counter)
      self.step_counter.assign_add(1)
    return self.step_counter

  def train_with_tpu_summary_optimization_no_decorator(self, num_steps):
    for _ in tf.range(num_steps):
      tf.summary.scalar("step", self.step_counter, step=self.step_counter)
      self.step_counter.assign_add(1)
    return self.step_counter


class TpuSummariesTest(tf.test.TestCase):

  def setUp(self):
    super().setUp()
    self.trainer = DummyTrainer()

  def _get_events_from_logdir(self, logdir):
    event_files = tf.io.gfile.listdir(logdir)
    self.assertLen(event_files, 1)
    path = os.path.join(logdir, event_files[0])
    events = list(tf.compat.v1.train.summary_iterator(path))
    return [event for event in events if event.WhichOneof("what") == "summary"]

  def _validate_tpu_summary_optimization(self, function, *args, **kwargs):
    logdir = self.get_temp_dir()
    with tf.summary.create_file_writer(logdir).as_default():
      with tf.summary.record_if(lambda: self.trainer.step_counter % 20 == 0):
        for _ in range(4):
          output = function(tf.constant(10), *args, **kwargs)
    events = self._get_events_from_logdir(logdir)
    self.assertLen(events, 2)
    self.assertEqual(events[0].step, 0)
    self.assertEqual(events[1].step, 20)
    return output

  def test_train_with_tpu_summary_optimization(self):
    output = self._validate_tpu_summary_optimization(
        self.trainer.train_with_tpu_summary_optimization)
    self.assertEqual(output, self.trainer.step_counter.numpy())

  def test_train_with_tpu_summary_optimization_no_decorator(self):
    optimized = train_function_with_summaries(
        self.trainer.train_with_tpu_summary_optimization_no_decorator)
    output = self._validate_tpu_summary_optimization(optimized)
    self.assertEqual(output, self.trainer.step_counter.numpy())

  def test_train_with_tpu_summary_optimization_and_input_signature(self):
    output = self._validate_tpu_summary_optimization(
        self.trainer.train_with_tpu_summary_optimization_and_input_signature)
    self.assertEqual(output, self.trainer.step_counter.numpy())
    function = self.trainer.train_with_tpu_summary_optimization_and_input_signature
    expected = (tf.TensorSpec((), dtype=tf.int32),)
    input_signature = function.with_summaries.input_signature
    self.assertEqual(input_signature, expected)
    input_signature = function.without_summaries.input_signature
    self.assertEqual(input_signature, expected)


if __name__ == "__main__":
  tf.test.main()