# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Multitask training driver library."""
# pytype: disable=attribute-error
import os
from typing import Any, List, Mapping, Optional, Tuple, Union
from absl import logging
import orbit
import tensorflow as tf, tf_keras
from official.core import base_task
from official.core import base_trainer as core_lib
from official.core import train_utils
from official.modeling.multitask import base_model
from official.modeling.multitask import base_trainer
from official.modeling.multitask import configs
from official.modeling.multitask import evaluator as evaluator_lib
from official.modeling.multitask import interleaving_trainer
from official.modeling.multitask import multitask
from official.modeling.multitask import task_sampler
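
# Maps the `trainer_type` config field to the corresponding multi-task trainer:
# 'interleaving' samples one task to step on at each training iteration, while
# 'joint' optimizes the combined losses of all tasks at every step.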
TRAINERS = {
'interleaving': interleaving_trainer.MultiTaskInterleavingTrainer,
'joint': base_trainer.MultiTaskBaseTrainer
}


def run_experiment(
*,
distribution_strategy: tf.distribute.Strategy,
task: multitask.MultiTask,
model: base_model.MultiTaskBaseModel,
mode: str,
params: configs.MultiTaskExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
    trainer: Optional[base_trainer.MultiTaskBaseTrainer] = None,
eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None,
    best_ckpt_exporter_creator: Optional[Any] = (
        train_utils.maybe_create_best_ckpt_exporter)
) -> Union[base_model.MultiTaskBaseModel, Tuple[base_model.MultiTaskBaseModel,
Mapping[Any, Any]]]:
"""Runs train/eval configured by the experiment params.
Args:
distribution_strategy: A distribution distribution_strategy.
task: A MultiTaskTask instance.
model: A MultiTaskBaseModel instance.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
params: ExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
trainer: (optional) A multi-task trainer to use. If none is provided, a
default one will be created based on `params`.
eval_summary_manager: Instance of the eval summary manager. If set, the
`eval_summary_dir` will be ignored. Otherwise the eval summary manager
will be created internally for TensorBoard summaries by default from the
`eval_summary_dir`.
best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
Returns:
model: `base_model.MultiTaskBaseModel` instance.
"""
is_training = 'train' in mode
is_eval = 'eval' in mode
with distribution_strategy.scope():
optimizer = train_utils.create_optimizer(task, params)
kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer)
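    # The interleaving trainer additionally needs a task sampler that decides
    # which task to step on at each iteration; the joint trainer steps all
    # tasks together and needs no sampler.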
if params.trainer.trainer_type == 'interleaving':
sampler = task_sampler.get_task_sampler(params.trainer.task_sampler,
task.task_weights)
kwargs.update(dict(task_sampler=sampler))
if trainer is None:
trainer = TRAINERS[params.trainer.trainer_type](
**kwargs) if is_training else None
if is_eval:
eval_steps = task.task_eval_steps
evaluator = evaluator_lib.MultiTaskEvaluator(
eval_tasks=task.tasks.values(),
model=model,
eval_steps=eval_steps,
global_step=trainer.global_step if is_training else None,
checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir))
else:
evaluator = None
if trainer:
checkpoint = trainer.checkpoint
global_step = trainer.global_step
else:
checkpoint = evaluator.checkpoint
global_step = evaluator.global_step
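  # Manage checkpoints against whichever object owns the training state: the
  # trainer when training, otherwise the evaluator.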
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=model_dir,
max_to_keep=params.trainer.max_to_keep,
step_counter=global_step,
checkpoint_interval=params.trainer.checkpoint_interval,
init_fn=model.initialize)
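  # Orbit's Controller drives the outer train/eval loop, restoring checkpoints
  # and writing summaries at the configured intervals.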
controller = orbit.Controller(
strategy=distribution_strategy,
trainer=trainer,
evaluator=evaluator,
global_step=global_step,
steps_per_loop=params.trainer.steps_per_loop,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(model_dir, 'train'),
eval_summary_dir=os.path.join(model_dir, 'validation'),
eval_summary_manager=eval_summary_manager,
summary_interval=params.trainer.summary_interval)
  logging.info('Starting to execute mode: %s', mode)
with distribution_strategy.scope():
if mode == 'train':
controller.train(steps=params.trainer.train_steps)
elif mode == 'train_and_eval':
controller.train_and_evaluate(
train_steps=params.trainer.train_steps,
eval_steps=params.trainer.validation_steps,
eval_interval=params.trainer.validation_interval)
elif mode == 'eval':
controller.evaluate(steps=params.trainer.validation_steps)
elif mode == 'continuous_eval':

      def timeout_fn():
        # Stop continuous evaluation once training has reached its final step.
        return bool(
            evaluator.global_step.numpy() >= params.trainer.train_steps)

controller.evaluate_continuously(
steps=params.trainer.validation_steps,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
if run_post_eval:
return model, evaluator.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps)) # pytype: disable=bad-return-type # typed-keras
else:
return model
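

# A minimal usage sketch for `run_experiment` (illustrative only; the config,
# task, and model construction below are assumptions about the surrounding
# codebase, not guarantees of this module):
#
#   strategy = tf.distribute.MirroredStrategy()
#   params = ...  # a configs.MultiTaskExperimentConfig for the experiment
#   with strategy.scope():
#     task = multitask.MultiTask.from_config(params.task)  # assumed factory
#     model = ...  # a base_model.MultiTaskBaseModel built for `task`
#   model = run_experiment(
#       distribution_strategy=strategy,
#       task=task,
#       model=model,
#       mode='train_and_eval',
#       params=params,
#       model_dir='/tmp/multitask_model')
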
def run_experiment_with_multitask_eval(
*,
distribution_strategy: tf.distribute.Strategy,
train_task: base_task.Task,
eval_tasks: List[base_task.Task],
mode: str,
params: configs.MultiEvalExperimentConfig,
model_dir: str,
run_post_eval: bool = False,
save_summary: bool = True,
trainer: Optional[core_lib.Trainer] = None,
eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None,
    best_ckpt_exporter_creator: Optional[Any] = (
        train_utils.maybe_create_best_ckpt_exporter),
) -> Tuple[Any, Any]:
"""Runs train/eval configured by the experiment params.
Args:
distribution_strategy: A distribution distribution_strategy.
train_task: A base_task.Task instance.
eval_tasks: A list of evaluation tasks.
mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval'
or 'continuous_eval'.
params: MultiEvalExperimentConfig instance.
model_dir: A 'str', a path to store model checkpoints and summaries.
run_post_eval: Whether to run post eval once after training, metrics logs
are returned.
save_summary: Whether to save train and validation summary.
trainer: the core_lib.Trainer instance. It should be created within the
strategy.scope(). If not provided, an instance will be created by default
if `mode` contains 'train'.
eval_summary_manager: Instance of the eval summary manager. If set, the
`eval_summary_dir` will be ignored. Otherwise the eval summary manager
will be created internally for TensorBoard summaries by default from the
`eval_summary_dir`.
best_ckpt_exporter_creator: A functor for creating best checkpoint exporter.
Returns:
model: `tf_keras.Model` instance.
"""
is_training = 'train' in mode
is_eval = 'eval' in mode
with distribution_strategy.scope():
if is_training:
trainer = trainer or core_lib.Trainer(
config=params,
task=train_task,
model=train_task.build_model(),
optimizer=train_utils.create_optimizer(train_task, params),
train=True,
evaluate=False)
else:
trainer = None
    # Build the model or fetch the pre-cached one (which could be either a
    # multi-task model or a single-task model).
model = None
if trainer is None:
if isinstance(train_task, multitask.MultiTask):
model = train_task.build_multitask_model()
else:
model = train_task.build_model()
else:
if isinstance(trainer, base_trainer.MultiTaskBaseTrainer):
model = trainer.multi_task_model
else:
model = trainer.model
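    # When evaluating, read per-task eval step counts from the experiment
    # config's eval task routines, keyed by task name.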
if is_eval:
      eval_steps = {
          task_routine.task_config.name: task_routine.eval_steps
          for task_routine in params.eval_tasks
      }
evaluator = evaluator_lib.MultiTaskEvaluator(
eval_tasks=eval_tasks,
model=model,
global_step=trainer.global_step if is_training else None,
eval_steps=eval_steps,
checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir))
else:
evaluator = None
if trainer:
checkpoint = trainer.checkpoint
global_step = trainer.global_step
else:
checkpoint = evaluator.checkpoint
global_step = evaluator.global_step
checkpoint_manager = tf.train.CheckpointManager(
checkpoint,
directory=model_dir,
max_to_keep=params.trainer.max_to_keep,
step_counter=global_step,
checkpoint_interval=params.trainer.checkpoint_interval,
init_fn=trainer.initialize if trainer else None)
controller = orbit.Controller(
strategy=distribution_strategy,
trainer=trainer,
evaluator=evaluator,
global_step=global_step,
steps_per_loop=params.trainer.steps_per_loop,
checkpoint_manager=checkpoint_manager,
summary_dir=os.path.join(model_dir, 'train') if save_summary else None,
      eval_summary_dir=(os.path.join(model_dir, 'validation')
                        if save_summary else None),
      eval_summary_manager=eval_summary_manager,
      summary_interval=(params.trainer.summary_interval
                        if save_summary else None))
  logging.info('Starting to execute mode: %s', mode)
with distribution_strategy.scope():
if mode == 'train':
controller.train(steps=params.trainer.train_steps)
elif mode == 'train_and_eval':
controller.train_and_evaluate(
train_steps=params.trainer.train_steps,
eval_steps=params.trainer.validation_steps,
eval_interval=params.trainer.validation_interval)
elif mode == 'eval':
controller.evaluate(steps=params.trainer.validation_steps)
elif mode == 'continuous_eval':

      def timeout_fn():
        # Stop continuous evaluation once training has reached its final step.
        return bool(
            evaluator.global_step.numpy() >= params.trainer.train_steps)

controller.evaluate_continuously(
steps=params.trainer.validation_steps,
timeout=params.trainer.continuous_eval_timeout,
timeout_fn=timeout_fn)
else:
raise NotImplementedError('The mode is not implemented: %s' % mode)
if run_post_eval:
return model, evaluator.evaluate(
tf.convert_to_tensor(params.trainer.validation_steps)) # pytype: disable=bad-return-type # typed-keras
else:
return model, {} # pytype: disable=bad-return-type # typed-keras
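

# A minimal usage sketch for `run_experiment_with_multitask_eval` (illustrative
# only; the task and config construction below are assumptions):
#
#   strategy = tf.distribute.MirroredStrategy()
#   params = ...  # a configs.MultiEvalExperimentConfig for the experiment
#   train_task = ...  # a base_task.Task to train on
#   eval_tasks = [...]  # base_task.Task instances matching params.eval_tasks
#   model, eval_logs = run_experiment_with_multitask_eval(
#       distribution_strategy=strategy,
#       train_task=train_task,
#       eval_tasks=eval_tasks,
#       mode='train_and_eval',
#       params=params,
#       model_dir='/tmp/multitask_eval_model',
#       run_post_eval=True)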