File size: 3,519 Bytes
5672777
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
# Copyright 2023 The Orbit Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Provides AbstractTrainer/Evaluator base classes, defining train/eval APIs."""

import abc

from typing import Dict, Optional, Union

import numpy as np
import tensorflow as tf, tf_keras


# Result type for `train`/`evaluate`: a string-keyed dict whose values are
# TensorFlow tensors, Python floats, NumPy scalars/arrays, or another `Output`.
# The quoted forward reference 'Output' makes the alias recursive, so results
# may be nested arbitrarily deep. The trailing pragma silences pytype's
# `not-supported-yet` error for this recursive alias.
Output = Dict[str, Union[tf.Tensor, float, np.number, np.ndarray, 'Output']]  # pytype: disable=not-supported-yet


class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta):
  """Abstract interface that trainers driven by a `Controller` implement.

  Concrete subclasses supply the training "inner loop" by overriding `train`.
  """

  @abc.abstractmethod
  def train(self, num_steps: tf.Tensor) -> Optional[Output]:
    """Runs `num_steps` steps of training.

    The `Controller` invokes this method as the "inner loop" of training.
    Running several steps per call spreads the fixed cost of bookkeeping
    (checkpointing, evaluation, and summary writing) across those steps. The
    loop may be written with TensorFlow looping constructs, for example a
    `for` loop over a `tf.range` inside a `tf.function`. This can be necessary
    to reach optimal performance on TPU. When peak performance is not a
    concern, a plain Python loop is the simpler choice.

    Args:
      num_steps: How many training steps to run. The model decides what counts
        as a single "step"; one step may perform more than one update to the
        model parameters (for instance, when training a GAN).

    Returns:
      `None`, or a dictionary mapping names to `Tensor`s, Python floats, or
      NumPy values. A returned dictionary is written to the logs and as
      TensorBoard summaries. It may be nested, in which case a hierarchy of
      summary directories is generated.
    """


class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta):
  """An abstract class defining the API required for evaluation.

  Concrete subclasses supply the evaluation logic by overriding `evaluate`.
  """

  @abc.abstractmethod
  def evaluate(self, num_steps: tf.Tensor) -> Optional[Output]:
    """Implements `num_steps` steps of evaluation.

    This method will be called by the `Controller` to perform an evaluation.
    The `num_steps` parameter specifies the number of steps of evaluation to
    run, which is specified by the user when calling one of the `Controller`'s
    evaluation methods. A special sentinel value of `-1` is reserved to
    indicate that evaluation should run until the underlying data source is
    exhausted.

    Args:
      num_steps: The number of evaluation steps to run. Note that it is up to
        the model what constitutes a "step". Evaluations may also want to
        support "complete" evaluations when `num_steps == -1`, running until a
        given data source is exhausted.

    Returns:
      Either `None`, or a dictionary mapping names to `Tensor`s, Python floats,
      or NumPy values. If a dictionary is returned, it will be written to logs
      and as TensorBoard summaries. The dictionary may also be nested, which
      will generate a hierarchy of summary directories.
    """