# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Standard Trainer implementation. | |
The base trainer implements the Orbit `StandardTrainable` and | |
`StandardEvaluable` interfaces. Trainers inside this project should be | |
interchangable and independent on model architectures and tasks. | |
""" | |

import functools
from typing import Union, Optional

from absl import logging
import gin
import orbit
import tensorflow as tf, tf_keras

from official.core import base_task
from official.core import config_definitions
from official.modeling import optimization

ExperimentConfig = config_definitions.ExperimentConfig
TrainerConfig = config_definitions.TrainerConfig


class _AsyncTrainer(orbit.StandardTrainer, orbit.StandardEvaluator):
  """Trainer class for both sync and async Strategy."""

  def init_async(self):
    """Initializes the Async Trainer base class."""
    assert isinstance(self._strategy, tf.distribute.Strategy)
    self._is_async = isinstance(
        self._strategy, tf.distribute.experimental.ParameterServerStrategy)
    self._coordinator = None
    if self._is_async:
      self._coordinator = (
          tf.distribute.experimental.coordinator.ClusterCoordinator(
              self._strategy))
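
  # Async mode is only entered under `ParameterServerStrategy`. A hedged
  # sketch of a caller-side setup (the cluster resolver choice here is an
  # assumption, not prescribed by this module):
  #
  #   strategy = tf.distribute.experimental.ParameterServerStrategy(
  #       tf.distribute.cluster_resolver.TFConfigClusterResolver())
  #   with strategy.scope():
  #     trainer = Trainer(config, task, model, optimizer)
  #   # `Trainer.__init__` calls `init_async()` itself.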

  def coordinator_for_async(
      self,
  ) -> tf.distribute.experimental.coordinator.ClusterCoordinator:
    if not self._coordinator:
      raise ValueError(
          "Coordinator uninitialized for async run. Call init_async() first."
      )
    return self._coordinator

  def join(self):
    """Joins all async steps. Only useful in async training."""
    if getattr(self, "_is_async", False):
      self.coordinator_for_async().join()

  def create_train_loop_fn(self):
    """Creates a training loop from the given step function and options."""
    train_loop_fn = super().create_train_loop_fn()
    if getattr(self, "_is_async", False):

      def _async_loop_fn(iterator, num_steps):
        self.coordinator_for_async().schedule(
            train_loop_fn, args=(iterator, num_steps)
        )

      return _async_loop_fn
    else:
      return train_loop_fn

  def create_eval_loop_fn(self, has_state: bool):
    """Creates an eval loop from the given step function and options."""
    eval_loop_fn = super().create_eval_loop_fn(has_state)
    if getattr(self, "_is_async", False):
      if has_state:
        raise ValueError(
            "Stateful eval loop is not supported in async training.")

      def _async_loop_fn(iterator, num_steps, state=None, reduce_fn=None):
        assert state is None
        assert reduce_fn is None
        self.coordinator_for_async().schedule(
            eval_loop_fn, args=(iterator, num_steps)
        )

      return _async_loop_fn
    else:
      return eval_loop_fn
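
  # Note: `ClusterCoordinator.schedule` only enqueues the step function and
  # returns immediately; `join()` above blocks until every scheduled step has
  # finished on the workers.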

  def distribute_dataset(self, dataset_or_fn, *args, **kwargs):
    """A utility function to help create a `tf.distribute.DistributedDataset`.

    Args:
      dataset_or_fn: An instance of `tf.data.Dataset`, or a "dataset function"
        returning a `tf.data.Dataset`. If it is a function, it may optionally
        have an argument named `input_context` which will be passed a
        `tf.distribute.InputContext` instance.
      *args: Any positional arguments to pass through to `dataset_or_fn`.
      **kwargs: Any keyword arguments to pass through to `dataset_or_fn`.

    Returns:
      A distributed Dataset.
    """
    if getattr(self, "_is_async", False):
      per_worker_dataset_fn = functools.partial(
          orbit.utils.make_distributed_dataset, self._strategy, dataset_or_fn,
          *args, **kwargs)
      per_worker_dataset_fn = tf.function(per_worker_dataset_fn)
      return self.coordinator_for_async().create_per_worker_dataset(
          per_worker_dataset_fn
      )
    else:
      return orbit.utils.make_distributed_dataset(self._strategy, dataset_or_fn,
                                                  *args, **kwargs)
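
  # A hedged sketch of a "dataset function" accepted above (`dataset_fn` is a
  # hypothetical example, not part of this module):
  #
  #   def dataset_fn(input_context=None):
  #     ds = tf.data.Dataset.range(1024).batch(8)
  #     if input_context is not None:
  #       ds = ds.shard(input_context.num_input_pipelines,
  #                     input_context.input_pipeline_id)
  #     return ds
  #
  #   train_ds = trainer.distribute_dataset(dataset_fn)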


def get_runtime_options(config: ExperimentConfig):
  """Get tf.distribute.RunOptions from config."""
  xla_options = {}
  if config.runtime.tpu_enable_xla_dynamic_padder is not None:
    xla_options["enable_xla_dynamic_padder"] = (
        config.runtime.tpu_enable_xla_dynamic_padder)
  return tf.distribute.RunOptions(
      experimental_xla_options=tf.tpu.XLAOptions(**xla_options))
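
# `enable_xla_dynamic_padder` toggles XLA's dynamic-padder handling of
# dynamic input shapes on TPU. The resulting `RunOptions` is applied to train
# steps only; eval uses the defaults (see the comment in `Trainer.__init__`).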


@gin.configurable
class Trainer(_AsyncTrainer):
  """Implements the common trainer shared for TensorFlow models."""

  # pylint: disable=super-init-not-called
  def __init__(
      self,
      config: ExperimentConfig,
      task: base_task.Task,
      model: tf_keras.Model,
      optimizer: tf.optimizers.Optimizer,
      train: bool = True,
      evaluate: bool = True,
      train_dataset: Optional[Union[tf.data.Dataset,
                                    tf.distribute.DistributedDataset]] = None,
      validation_dataset: Optional[Union[
          tf.data.Dataset, tf.distribute.DistributedDataset]] = None,
      checkpoint_exporter=None):
"""Initialize common trainer for TensorFlow models. | |
Args: | |
config: An `ExperimentConfig` instance specifying experiment config. | |
task: A base_task.Task instance. | |
model: The model instance, e.g. a tf_keras.Model instance. | |
optimizer: tf.optimizers.Optimizer instance. | |
train: bool, whether or not this trainer will be used for training. | |
default to True. | |
evaluate: bool, whether or not this trainer will be used for evaluation. | |
default to True. | |
train_dataset: a dataset object created for training. With tf.distribute, | |
it needs to be a `DistributedDataset`. | |
validation_dataset: a dataset object created for evaluation. With | |
tf.distribute, it needs to be a `DistributedDataset`. The evaluator will | |
create a dataset iterator for each eval round, so the dataset does not | |
need to repeat. | |
checkpoint_exporter: an object that has the `maybe_export_checkpoint` | |
interface. | |
""" | |
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
    self._validate_params(
        config,
        check_train_data=train_dataset is None,
        check_validation_data=validation_dataset is None)
    self._config = config
    self._task = task
    self._model = model
    self._optimizer = optimizer
    self._checkpoint_exporter = checkpoint_exporter
    self._recovery = None
    # Runtime options are only applied to train_step.
    # We use default for eval_step.
    self._runtime_options = get_runtime_options(config)

    # Creates a shadow copy of the weights to store weights moving average.
    if isinstance(self._optimizer, optimization.ExponentialMovingAverage
                  ) and not self._optimizer.has_shadow_copy:
      self._optimizer.shadow_copy(self._model)

    # global_step increases by 1 after each training iteration.
    # We should have global_step.numpy() == self.optimizer.iterations.numpy()
    # when there is only 1 optimizer.
    self._global_step = orbit.utils.create_global_step()
    if hasattr(self.model, "checkpoint_items"):
      checkpoint_items = self.model.checkpoint_items
    else:
      checkpoint_items = {}
    self._checkpoint = tf.train.Checkpoint(
        global_step=self.global_step,
        model=self.model,
        optimizer=self.optimizer,
        **checkpoint_items)

    self._train_loss = tf_keras.metrics.Mean("training_loss", dtype=tf.float32)
    self._validation_loss = tf_keras.metrics.Mean(
        "validation_loss", dtype=tf.float32)
    model_metrics = model.metrics if hasattr(model, "metrics") else []

    self.init_async()

    if train:
      self._train_metrics = self.task.build_metrics(
          training=True) + model_metrics
      train_dataset = train_dataset or self.distribute_dataset(
          self.task.build_inputs, self.config.task.train_data)
      orbit.StandardTrainer.__init__(
          self,
          train_dataset,
          options=orbit.StandardTrainerOptions(
              use_tf_while_loop=config.trainer.train_tf_while_loop,
              use_tf_function=config.trainer.train_tf_function,
              use_tpu_summary_optimization=config.trainer.allow_tpu_summary))

    if evaluate:
      self._validation_metrics = self.task.build_metrics(
          training=False) + model_metrics
      validation_dataset = validation_dataset or self.distribute_dataset(
          self.task.build_inputs, self.config.task.validation_data)
      orbit.StandardEvaluator.__init__(
          self,
          validation_dataset,
          options=orbit.StandardEvaluatorOptions(
              use_tf_function=config.trainer.eval_tf_function,
              use_tf_while_loop=config.trainer.eval_tf_while_loop))

  def _validate_params(self,
                       config,
                       check_train_data=True,
                       check_validation_data=True):
    r"""Validates the configuration object passed to the Trainer.

    The experiment configuration should be structured as:
    \trainer
    \task
      \train_data
      \validation_data

    Args:
      config: a namedtuple, dataclass, ConfigDict, etc.
      check_train_data: whether to check task.train_data field.
      check_validation_data: whether to check task.validation_data field.
    """
    if not hasattr(config, "trainer"):
      raise AttributeError("The trainer requires the configuration to contain"
                           " an attribute `trainer`.")
    if not hasattr(config, "task"):
      raise AttributeError("The trainer requires the configuration to contain"
                           " an attribute `task`.")
    if check_train_data and not hasattr(config.task, "train_data"):
      raise AttributeError("The trainer requires the configuration to contain"
                           " an attribute `task.train_data`.")
    if check_validation_data and not hasattr(config.task, "validation_data"):
      raise AttributeError("The trainer requires the configuration to contain"
                           " an attribute `task.validation_data`.")

  @property
  def strategy(self):
    return self._strategy

  @property
  def config(self):
    return self._config

  @property
  def task(self):
    return self._task

  @property
  def model(self):
    return self._model

  @property
  def optimizer(self):
    if hasattr(self, "_optimizer"):
      return self._optimizer
    else:
      return None

  @property
  def global_step(self):
    return self._global_step

  @property
  def train_loss(self):
    """Accesses the training loss metric object."""
    return self._train_loss

  @property
  def validation_loss(self):
    """Accesses the validation loss metric object."""
    return self._validation_loss

  @property
  def train_metrics(self):
    """Accesses all training metric objects."""
    return self._train_metrics

  @property
  def validation_metrics(self):
    """Accesses all validation metric objects."""
    return self._validation_metrics

  def initialize(self):
    """A callback function.

    This function will be called when no checkpoint is found for the model.
    If there is a checkpoint, the checkpoint will be loaded and this function
    will not be called. Tasks may use this callback function to load a
    pretrained checkpoint, saved under a directory other than the model_dir.
    """
    self.task.initialize(self.model)

  @property
  def checkpoint(self):
    """Accesses the training checkpoint."""
    return self._checkpoint

  @property
  def checkpoint_exporter(self):
    """Accesses the checkpoint exporter."""
    return self._checkpoint_exporter

  def train_loop_end(self):
    """See base class."""
    self.join()
    logs = {}
    for metric in self.train_metrics + [self.train_loss]:
      logs[metric.name] = metric.result()
      metric.reset_states()
    if callable(self.optimizer.learning_rate):
      # A self-implemented optimizer may not have `optimizer.iterations`,
      # so fall back to the global step just to be safe.
      if hasattr(self.optimizer, "iterations"):
        logs["learning_rate"] = self.optimizer.learning_rate(
            self.optimizer.iterations)
      else:
        logs["learning_rate"] = self.optimizer.learning_rate(self.global_step)
    else:
      logs["learning_rate"] = self.optimizer.learning_rate
    return logs
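
  # A callable `learning_rate` typically indicates a
  # `tf_keras.optimizers.schedules.LearningRateSchedule`; calling it with the
  # current step returns the scalar rate logged above.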

  def next_train_inputs(self, iterator):
    """Fetches the next inputs for the model during train.

    This method consumes the input iterator and returns the next inputs for
    the model.

    This method provides a way to control how to fetch the next model input,
    and what data to send to the model.

    Note: This function runs on the host side when accelerators are used.

    Note: Depending on the training setup this may or may not run in eager
    mode. In most cases it will be run in graph mode.

    Args:
      iterator: Dataset iterator to generate the next inputs from.

    Returns:
      The inputs to the model.
    """
    return next(iterator)
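
  # A hedged override sketch (`MyTrainer` and the feature names are
  # hypothetical, not part of this module):
  #
  #   class MyTrainer(Trainer):
  #
  #     def next_train_inputs(self, iterator):
  #       features, labels = next(iterator)
  #       # e.g. drop host-only fields before sending data to the accelerator.
  #       features.pop("debug_string", None)
  #       return features, labels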

  def train_step(self, iterator):
    """See base class."""

    def step_fn(inputs):
      if self.config.runtime.enable_xla and (self.config.runtime.num_gpus > 0):
        task_train_step = tf.function(self.task.train_step, jit_compile=True)
      else:
        task_train_step = self.task.train_step
      logs = task_train_step(
          inputs,
          model=self.model,
          optimizer=self.optimizer,
          metrics=self.train_metrics)
      self._train_loss.update_state(logs[self.task.loss])
      self.global_step.assign_add(1)

    inputs = self.next_train_inputs(iterator)
    self.strategy.run(step_fn, args=(inputs,), options=self._runtime_options)

  def eval_begin(self):
    """Sets up metrics."""
    for metric in self.validation_metrics + [self.validation_loss]:
      metric.reset_states()
    # Swaps weights to test on weights moving average.
    if self.optimizer and isinstance(self.optimizer,
                                     optimization.ExponentialMovingAverage):
      self.optimizer.swap_weights()

  def next_eval_inputs(self, iterator):
    """Fetches the next inputs for the model during eval.

    This method consumes the input iterator and returns the next inputs for
    the model and an additional logs dict. The logs dict remains on the host
    (it is not sent to GPUs/TPUs) and is merged with the model outputs, which
    are processed later in `aggregate_logs`. This is useful for sending extra
    logs downstream that are not compatible with the accelerators.

    Note: This function runs on the host side when accelerators are used.

    Note: Depending on the training setup this may or may not run in eager
    mode. In most cases it will be run in graph mode.

    Args:
      iterator: Dataset iterator to generate the next inputs from.

    Returns:
      The inputs to the model, and an additional logs dictionary. The logs
      are not passed to the model; instead they are merged with the model
      output logs.
    """
    passthrough_logs = dict()
    return next(iterator), passthrough_logs
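
  # A hedged override sketch for host-side passthrough logs (`MyTrainer` and
  # `example_id` are hypothetical names):
  #
  #   class MyTrainer(Trainer):
  #
  #     def next_eval_inputs(self, iterator):
  #       inputs = next(iterator)
  #       # Keep string ids on the host; they are not accelerator-friendly.
  #       passthrough = {"example_id": inputs.pop("example_id")}
  #       return inputs, passthrough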

  def eval_step(self, iterator):
    """See base class."""

    def step_fn(inputs):
      logs = self.task.validation_step(
          inputs, model=self.model, metrics=self.validation_metrics)
      if self.task.loss in logs:
        self._validation_loss.update_state(logs[self.task.loss])
      return logs

    inputs, passthrough_logs = self.next_eval_inputs(iterator)
    distributed_outputs = self.strategy.run(step_fn, args=(inputs,))
    logs = tf.nest.map_structure(
        self.strategy.experimental_local_results, distributed_outputs
    )

    if set(logs.keys()) & set(passthrough_logs.keys()):
      logging.warning(
          (
              "Conflict between the passthrough log keys and the returned"
              " model log keys. Found %r keys in the passthrough logs and %r"
              " keys in the model logs. Model log keys take precedence."
          ),
          logs.keys(),
          passthrough_logs.keys(),
      )

    return passthrough_logs | logs

  def eval_end(self, aggregated_logs=None):
    """Processes evaluation results."""
    self.join()
    logs = {}
    for metric in self.validation_metrics:
      logs[metric.name] = metric.result()
    if self.validation_loss.count.numpy() != 0:
      logs[self.validation_loss.name] = self.validation_loss.result()
    else:
      # The `self.validation_loss` metric was not updated, because the
      # validation loss was not returned from the task's `validation_step`
      # method.
      logging.info("The task did not report validation loss.")
    if aggregated_logs:
      metrics = self.task.reduce_aggregated_logs(
          aggregated_logs, global_step=self.global_step)
      logs.update(metrics)

    if self._checkpoint_exporter:
      self._checkpoint_exporter.maybe_export_checkpoint(
          self.checkpoint, logs, self.global_step.numpy())
      metric_name = self.config.trainer.best_checkpoint_eval_metric
      logs["best_" +
           metric_name] = self._checkpoint_exporter.best_ckpt_logs[metric_name]

    # Swaps back weights after testing when EMA is used.
    # This happens after the best checkpoint export so that the average
    # weights used for eval are exported instead of the regular weights.
    if self.optimizer and isinstance(self.optimizer,
                                     optimization.ExponentialMovingAverage):
      self.optimizer.swap_weights()
    return logs

  def eval_reduce(self, state=None, step_outputs=None):
    return self.task.aggregate_logs(state, step_outputs)