# Copyright 2023 The TensorFlow Authors. All Rights Reserved. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. # You may obtain a copy of the License at # # http://www.apache.org/licenses/LICENSE-2.0 # # Unless required by applicable law or agreed to in writing, software # distributed under the License is distributed on an "AS IS" BASIS, # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. """Multitask training driver library.""" # pytype: disable=attribute-error import os from typing import Any, List, Mapping, Optional, Tuple, Union from absl import logging import orbit import tensorflow as tf, tf_keras from official.core import base_task from official.core import base_trainer as core_lib from official.core import train_utils from official.modeling.multitask import base_model from official.modeling.multitask import base_trainer from official.modeling.multitask import configs from official.modeling.multitask import evaluator as evaluator_lib from official.modeling.multitask import interleaving_trainer from official.modeling.multitask import multitask from official.modeling.multitask import task_sampler TRAINERS = { 'interleaving': interleaving_trainer.MultiTaskInterleavingTrainer, 'joint': base_trainer.MultiTaskBaseTrainer } def run_experiment( *, distribution_strategy: tf.distribute.Strategy, task: multitask.MultiTask, model: base_model.MultiTaskBaseModel, mode: str, params: configs.MultiTaskExperimentConfig, model_dir: str, run_post_eval: bool = False, trainer: base_trainer.MultiTaskBaseTrainer = None, eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None, best_ckpt_exporter_creator: Optional[Any] = train_utils .maybe_create_best_ckpt_exporter ) -> Union[base_model.MultiTaskBaseModel, Tuple[base_model.MultiTaskBaseModel, Mapping[Any, Any]]]: """Runs train/eval configured by the experiment params. Args: distribution_strategy: A distribution distribution_strategy. task: A MultiTaskTask instance. model: A MultiTaskBaseModel instance. mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval' or 'continuous_eval'. params: ExperimentConfig instance. model_dir: A 'str', a path to store model checkpoints and summaries. run_post_eval: Whether to run post eval once after training, metrics logs are returned. trainer: (optional) A multi-task trainer to use. If none is provided, a default one will be created based on `params`. eval_summary_manager: Instance of the eval summary manager. If set, the `eval_summary_dir` will be ignored. Otherwise the eval summary manager will be created internally for TensorBoard summaries by default from the `eval_summary_dir`. best_ckpt_exporter_creator: A functor for creating best checkpoint exporter. Returns: model: `base_model.MultiTaskBaseModel` instance. """ is_training = 'train' in mode is_eval = 'eval' in mode with distribution_strategy.scope(): optimizer = train_utils.create_optimizer(task, params) kwargs = dict(multi_task=task, multi_task_model=model, optimizer=optimizer) if params.trainer.trainer_type == 'interleaving': sampler = task_sampler.get_task_sampler(params.trainer.task_sampler, task.task_weights) kwargs.update(dict(task_sampler=sampler)) if trainer is None: trainer = TRAINERS[params.trainer.trainer_type]( **kwargs) if is_training else None if is_eval: eval_steps = task.task_eval_steps evaluator = evaluator_lib.MultiTaskEvaluator( eval_tasks=task.tasks.values(), model=model, eval_steps=eval_steps, global_step=trainer.global_step if is_training else None, checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir)) else: evaluator = None if trainer: checkpoint = trainer.checkpoint global_step = trainer.global_step else: checkpoint = evaluator.checkpoint global_step = evaluator.global_step checkpoint_manager = tf.train.CheckpointManager( checkpoint, directory=model_dir, max_to_keep=params.trainer.max_to_keep, step_counter=global_step, checkpoint_interval=params.trainer.checkpoint_interval, init_fn=model.initialize) controller = orbit.Controller( strategy=distribution_strategy, trainer=trainer, evaluator=evaluator, global_step=global_step, steps_per_loop=params.trainer.steps_per_loop, checkpoint_manager=checkpoint_manager, summary_dir=os.path.join(model_dir, 'train'), eval_summary_dir=os.path.join(model_dir, 'validation'), eval_summary_manager=eval_summary_manager, summary_interval=params.trainer.summary_interval) logging.info('Starts to execute mode: %s', mode) with distribution_strategy.scope(): if mode == 'train': controller.train(steps=params.trainer.train_steps) elif mode == 'train_and_eval': controller.train_and_evaluate( train_steps=params.trainer.train_steps, eval_steps=params.trainer.validation_steps, eval_interval=params.trainer.validation_interval) elif mode == 'eval': controller.evaluate(steps=params.trainer.validation_steps) elif mode == 'continuous_eval': def timeout_fn(): if evaluator.global_step.numpy() >= params.trainer.train_steps: return True return False controller.evaluate_continuously( steps=params.trainer.validation_steps, timeout=params.trainer.continuous_eval_timeout, timeout_fn=timeout_fn) else: raise NotImplementedError('The mode is not implemented: %s' % mode) if run_post_eval: return model, evaluator.evaluate( tf.convert_to_tensor(params.trainer.validation_steps)) # pytype: disable=bad-return-type # typed-keras else: return model def run_experiment_with_multitask_eval( *, distribution_strategy: tf.distribute.Strategy, train_task: base_task.Task, eval_tasks: List[base_task.Task], mode: str, params: configs.MultiEvalExperimentConfig, model_dir: str, run_post_eval: bool = False, save_summary: bool = True, trainer: Optional[core_lib.Trainer] = None, eval_summary_manager: Optional[orbit.utils.SummaryManagerInterface] = None, best_ckpt_exporter_creator: Optional[Any] = train_utils .maybe_create_best_ckpt_exporter, ) -> Tuple[Any, Any]: """Runs train/eval configured by the experiment params. Args: distribution_strategy: A distribution distribution_strategy. train_task: A base_task.Task instance. eval_tasks: A list of evaluation tasks. mode: A 'str', specifying the mode. Can be 'train', 'eval', 'train_and_eval' or 'continuous_eval'. params: MultiEvalExperimentConfig instance. model_dir: A 'str', a path to store model checkpoints and summaries. run_post_eval: Whether to run post eval once after training, metrics logs are returned. save_summary: Whether to save train and validation summary. trainer: the core_lib.Trainer instance. It should be created within the strategy.scope(). If not provided, an instance will be created by default if `mode` contains 'train'. eval_summary_manager: Instance of the eval summary manager. If set, the `eval_summary_dir` will be ignored. Otherwise the eval summary manager will be created internally for TensorBoard summaries by default from the `eval_summary_dir`. best_ckpt_exporter_creator: A functor for creating best checkpoint exporter. Returns: model: `tf_keras.Model` instance. """ is_training = 'train' in mode is_eval = 'eval' in mode with distribution_strategy.scope(): if is_training: trainer = trainer or core_lib.Trainer( config=params, task=train_task, model=train_task.build_model(), optimizer=train_utils.create_optimizer(train_task, params), train=True, evaluate=False) else: trainer = None # Build the model or fetch the pre-cached one (which could be either # multi-task model or single task model). model = None if trainer is None: if isinstance(train_task, multitask.MultiTask): model = train_task.build_multitask_model() else: model = train_task.build_model() else: if isinstance(trainer, base_trainer.MultiTaskBaseTrainer): model = trainer.multi_task_model else: model = trainer.model if is_eval: eval_steps = dict([(task_routine.task_config.name, task_routine.eval_steps) for task_routine in params.eval_tasks]) evaluator = evaluator_lib.MultiTaskEvaluator( eval_tasks=eval_tasks, model=model, global_step=trainer.global_step if is_training else None, eval_steps=eval_steps, checkpoint_exporter=best_ckpt_exporter_creator(params, model_dir)) else: evaluator = None if trainer: checkpoint = trainer.checkpoint global_step = trainer.global_step else: checkpoint = evaluator.checkpoint global_step = evaluator.global_step checkpoint_manager = tf.train.CheckpointManager( checkpoint, directory=model_dir, max_to_keep=params.trainer.max_to_keep, step_counter=global_step, checkpoint_interval=params.trainer.checkpoint_interval, init_fn=trainer.initialize if trainer else None) controller = orbit.Controller( strategy=distribution_strategy, trainer=trainer, evaluator=evaluator, global_step=global_step, steps_per_loop=params.trainer.steps_per_loop, checkpoint_manager=checkpoint_manager, summary_dir=os.path.join(model_dir, 'train') if save_summary else None, eval_summary_dir=os.path.join(model_dir, 'validation') if (save_summary) else None, eval_summary_manager=eval_summary_manager, summary_interval=params.trainer.summary_interval if (save_summary) else None) logging.info('Starts to execute mode: %s', mode) with distribution_strategy.scope(): if mode == 'train': controller.train(steps=params.trainer.train_steps) elif mode == 'train_and_eval': controller.train_and_evaluate( train_steps=params.trainer.train_steps, eval_steps=params.trainer.validation_steps, eval_interval=params.trainer.validation_interval) elif mode == 'eval': controller.evaluate(steps=params.trainer.validation_steps) elif mode == 'continuous_eval': def timeout_fn(): if evaluator.global_step.numpy() >= params.trainer.train_steps: return True return False controller.evaluate_continuously( steps=params.trainer.validation_steps, timeout=params.trainer.continuous_eval_timeout, timeout_fn=timeout_fn) else: raise NotImplementedError('The mode is not implemented: %s' % mode) if run_post_eval: return model, evaluator.evaluate( tf.convert_to_tensor(params.trainer.validation_steps)) # pytype: disable=bad-return-type # typed-keras else: return model, {} # pytype: disable=bad-return-type # typed-keras