Spaces:
Runtime error
Runtime error
# Copyright 2023 The Orbit Authors. All Rights Reserved. | |
# | |
# Licensed under the Apache License, Version 2.0 (the "License"); | |
# you may not use this file except in compliance with the License. | |
# You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, software | |
# distributed under the License is distributed on an "AS IS" BASIS, | |
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
# See the License for the specific language governing permissions and | |
# limitations under the License. | |
"""Provides AbstractTrainer/Evaluator base classes, defining train/eval APIs.""" | |
import abc | |
from typing import Dict, Optional, Union | |
import numpy as np | |
import tensorflow as tf, tf_keras | |
Output = Dict[str, Union[tf.Tensor, float, np.number, np.ndarray, 'Output']] # pytype: disable=not-supported-yet | |
class AbstractTrainer(tf.Module, metaclass=abc.ABCMeta): | |
"""An abstract class defining the API required for training.""" | |
def train(self, num_steps: tf.Tensor) -> Optional[Output]: | |
"""Implements `num_steps` steps of training. | |
This method will be called by the `Controller` to perform the "inner loop" | |
of training. This inner loop amortizes the cost of bookkeeping associated | |
with checkpointing, evaluation, and writing summaries. Additionally, the | |
inner loop can be implemented (if desired) using TensorFlow's looping | |
constructs (e.g. a `for` loop over a `tf.range` inside a `tf.function`), | |
which can be necessary for getting optimal performance when running on TPU. | |
For cases that don't require peak performance, a simple Python loop can be | |
used instead for simplicity. | |
Args: | |
num_steps: The number of training steps to run. Note that it is up to the | |
model what constitutes a "step", which may involve more than one update | |
to model parameters (e.g., if training a GAN). | |
Returns: | |
Either `None`, or a dictionary mapping names to `Tensor`s or NumPy values. | |
If a dictionary is returned, it will be written to logs and as TensorBoard | |
summaries. The dictionary may also be nested, which will generate a | |
hierarchy of summary directories. | |
""" | |
pass | |
class AbstractEvaluator(tf.Module, metaclass=abc.ABCMeta): | |
"""An abstract class defining the API required for evaluation.""" | |
def evaluate(self, num_steps: tf.Tensor) -> Optional[Output]: | |
"""Implements `num_steps` steps of evaluation. | |
This method will by called the `Controller` to perform an evaluation. The | |
`num_steps` parameter specifies the number of steps of evaluation to run, | |
which is specified by the user when calling one of the `Controller`'s | |
evaluation methods. A special sentinel value of `-1` is reserved to indicate | |
evaluation should run until the underlying data source is exhausted. | |
Args: | |
num_steps: The number of evaluation steps to run. Note that it is up to | |
the model what constitutes a "step". Evaluations may also want to | |
support "complete" evaluations when `num_steps == -1`, running until a | |
given data source is exhausted. | |
Returns: | |
Either `None`, or a dictionary mapping names to `Tensor`s or NumPy values. | |
If a dictionary is returned, it will be written to logs and as TensorBoard | |
summaries. The dictionary may also be nested, which will generate a | |
hierarchy of summary directories. | |
""" | |
pass | |