# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Progressive Trainer implementation.

The trainer implements the Orbit `StandardTrainable` and
`StandardEvaluable` interfaces. Trainers inside this project should be
interchangeable and independent of model architectures and tasks.
"""

import dataclasses
import os
from typing import Any, Optional

# Import libraries
from absl import logging
import gin
import orbit
import tensorflow as tf, tf_keras

from official.core import base_task
from official.core import base_trainer as trainer_lib
from official.core import config_definitions
from official.modeling.fast_training.progressive import policies
from official.modeling.fast_training.progressive import utils

ExperimentConfig = config_definitions.ExperimentConfig


@dataclasses.dataclass
class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
  """Configuration for progressive trainer.

  Attributes:
    progressive: A task-specific config. Users can subclass ProgressiveConfig
      and define any task-specific settings in their subclass.
    export_checkpoint: A bool. Whether to export checkpoints in a
      non-progressive manner (without the volatiles wrapper) such that your
      down-stream tasks can load checkpoints from a progressive trainer as if
      they were regular checkpoints.
    export_checkpoint_interval: An int. The number of steps between exporting
      checkpoints. If None (by default), use the same value as
      TrainerConfig.checkpoint_interval.
    export_max_to_keep: An int. The maximum number of exported checkpoints to
      keep. If None (by default), use the same value as
      TrainerConfig.max_to_keep.
    export_only_final_stage_ckpt: A bool. Whether to export checkpoints only
      during the final progressive training stage, i.e. whether to skip
      exporting small, partial models. In many cases, it is not meaningful to
      finetune a small, partial model in down-stream tasks.
  """
  progressive: Optional[policies.ProgressiveConfig] = None
  export_checkpoint: bool = True
  export_checkpoint_interval: Optional[int] = None
  export_max_to_keep: Optional[int] = None
  export_only_final_stage_ckpt: bool = True
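

# Illustrative example (a sketch, not part of this module's API):
# `ProgressiveTrainerConfig` is constructed like any other TrainerConfig.
# `MyProgressiveConfig` below is a hypothetical task-specific subclass of
# policies.ProgressiveConfig.
#
#   trainer_config = ProgressiveTrainerConfig(
#       progressive=MyProgressiveConfig(),
#       export_checkpoint=True,
#       export_checkpoint_interval=1000,  # None -> use checkpoint_interval.
#       export_max_to_keep=5,             # None -> use max_to_keep.
#       export_only_final_stage_ckpt=True)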


@gin.configurable
class ProgressiveTrainer(trainer_lib.Trainer):
  """Implements the progressive trainer shared for TensorFlow models."""

  def __init__(
      self,
      config: ExperimentConfig,
      prog_task: base_task.Task,  # also implements ProgressivePolicy.
      ckpt_dir: str = '',
      train: bool = True,
      evaluate: bool = True,
      checkpoint_exporter: Any = None):
    """Initializes the common trainer for TensorFlow models.

    Args:
      config: An `ExperimentConfig` instance specifying experiment config.
      prog_task: An instance that implements both policies.ProgressivePolicy
        and base_task.Task.
      ckpt_dir: Checkpoint directory.
      train: bool, whether or not this trainer will be used for training.
        Defaults to True.
      evaluate: bool, whether or not this trainer will be used for evaluation.
        Defaults to True.
      checkpoint_exporter: An object that has the `maybe_export_checkpoint`
        interface.
    """
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
    self._config = config
    self._runtime_options = trainer_lib.get_runtime_options(config)
    self._task = prog_task

    # Directory for the non-progressive (exported) checkpoints.
    self._export_ckpt_dir = os.path.join(ckpt_dir, 'exported_ckpts')
    tf.io.gfile.makedirs(self._export_ckpt_dir)
    self._export_ckpt_manager = None

    # Receives other checkpoint exporters, e.g., the best checkpoint exporter.
    # TODO(lehou): unify the checkpoint exporting logic, although the default
    # setting does not use checkpoint_exporter.
    self._checkpoint_exporter = checkpoint_exporter

    self._global_step = orbit.utils.create_global_step()

    self._checkpoint = utils.CheckpointWithHooks(
        before_load_hook=self._update_pt_stage_from_ckpt,
        global_step=self.global_step,
        **self._task.cur_checkpoint_items)

    self._train_loss = tf_keras.metrics.Mean('training_loss', dtype=tf.float32)
    self._validation_loss = tf_keras.metrics.Mean(
        'validation_loss', dtype=tf.float32)
    self._train_metrics = self.task.build_metrics(
        training=True) + self.model.metrics
    self._validation_metrics = self.task.build_metrics(
        training=False) + self.model.metrics

    if train:
      orbit.StandardTrainer.__init__(
          self,
          None,  # Manage train_dataset by ourselves, not by StandardTrainer.
          options=orbit.StandardTrainerOptions(
              use_tf_while_loop=config.trainer.train_tf_while_loop,
              use_tf_function=config.trainer.train_tf_function))

    if evaluate:
      orbit.StandardEvaluator.__init__(
          self,
          None,  # Manage eval_dataset by ourselves, not by StandardEvaluator.
          options=orbit.StandardEvaluatorOptions(
              use_tf_function=config.trainer.eval_tf_function))

  @property
  def model(self):
    return self._task.cur_model

  @property
  def optimizer(self):
    return self._task.cur_optimizer

  # override
  @property
  def train_dataset(self):
    """Overriding StandardTrainer.train_dataset."""
    return self._task.cur_train_dataset

  # override
  @train_dataset.setter
  def train_dataset(self, _):
    raise SyntaxError('Please do not set train_dataset. Progressive training '
                      'relies on the progressive policy to manage the train '
                      'dataset.')

  # override
  @property
  def eval_dataset(self):
    """Overriding StandardEvaluator.eval_dataset."""
    return self._task.cur_eval_dataset

  # override
  @eval_dataset.setter
  def eval_dataset(self, _):
    raise SyntaxError('Please do not set eval_dataset. Progressive training '
                      'relies on the progressive policy to manage the eval '
                      'dataset.')

  def train_loop_end(self):
    """See base class."""
    logs = {}
    for metric in self.train_metrics + [self.train_loss]:
      logs[metric.name] = metric.result()
      metric.reset_states()
    if callable(self.optimizer.learning_rate):
      logs['learning_rate'] = self.optimizer.learning_rate(
          self.optimizer.iterations)
    else:
      logs['learning_rate'] = self.optimizer.learning_rate

    self._maybe_export_non_progressive_checkpoint(self._export_ckpt_dir)
    if self._task.is_stage_advancing(self.global_step.numpy()):
      old_train_dataset = self.train_dataset

      # Update the progressive properties to the new stage.
      self._task.update_pt_stage(self.global_step.numpy())

      # Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
      # rebuild the train and eval functions with the updated model.
      self._train_loop_fn = None
      self._eval_loop_fn = None

      if self.train_dataset != old_train_dataset:
        # Setting `self._train_iter` to None will rebuild the dataset iterator.
        self._train_iter = None

      # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
      # manager for exporting.
      self._export_ckpt_manager = None

    return logs
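
  # Note: `train_loop_end` above and `_update_pt_stage_from_ckpt` below share
  # the same reset pattern after a stage advance: the cached Orbit loop
  # functions (`_train_loop_fn`, `_eval_loop_fn`), the train iterator
  # (`_train_iter`) and the export checkpoint manager are set to None so they
  # are lazily rebuilt against the new stage's model, optimizer and datasets.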

  def _update_pt_stage_from_ckpt(self, ckpt_file):
    """Updates stage properties based on the global_step in a ckpt file.

    Before loading variables from a checkpoint file, we need to go to the
    correct stage and build the corresponding model and optimizer, to make
    sure that we restore variables of the right model and optimizer.

    Args:
      ckpt_file: Checkpoint file that will be restored/read from.
    """
    if not ckpt_file:
      return
    ckpt = tf.train.Checkpoint(global_step=self.global_step)
    ckpt.read(ckpt_file).expect_partial().assert_existing_objects_matched()

    if self._task.is_stage_advancing(self.global_step.numpy()):
      old_train_dataset = self.train_dataset

      # Update the progressive properties to the stage of the checkpoint.
      self._task.update_pt_stage(self.global_step.numpy(),
                                 pass_old_model=False)

      # Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
      # rebuild the train and eval functions with the updated model.
      self._train_loop_fn = None
      self._eval_loop_fn = None

      if self.train_dataset != old_train_dataset:
        # Setting `self._train_iter` to None will rebuild the dataset iterator.
        self._train_iter = None

      # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
      # manager for exporting.
      self._export_ckpt_manager = None

  def _maybe_export_non_progressive_checkpoint(self, export_ckpt_dir):
    """Exports checkpoints in non-progressive format.

    This basically removes the wrapping of self._task.cur_checkpoint_items --
    it just saves the model, optimizer, etc., directly. The purpose is to let
    your down-stream tasks use these checkpoints.

    Args:
      export_ckpt_dir: A str. Folder of the exported checkpoints.
    """
    if not self.config.trainer.export_checkpoint:
      logging.info('Not exporting checkpoints.')
      return
    if not self._task.is_last_stage and (
        self.config.trainer.export_only_final_stage_ckpt):
      logging.info('Not exporting checkpoints until the last stage.')
      return

    if self._export_ckpt_manager is None:
      # Create the checkpoint object now, to make sure we use
      # progressive_policy.cur_model and progressive_policy.cur_optimizer of
      # the current stage.
      if hasattr(self.model, 'checkpoint_items'):
        checkpoint_items = self.model.checkpoint_items
      else:
        checkpoint_items = {}
      checkpoint = tf.train.Checkpoint(
          global_step=self.global_step,
          model=self.model,
          optimizer=self.optimizer,
          **checkpoint_items)

      max_to_keep = self.config.trainer.export_max_to_keep or (
          self.config.trainer.max_to_keep)
      checkpoint_interval = self.config.trainer.export_checkpoint_interval or (
          self.config.trainer.checkpoint_interval)
      self._export_ckpt_manager = tf.train.CheckpointManager(
          checkpoint,
          directory=export_ckpt_dir,
          checkpoint_name='ckpt',
          step_counter=self.global_step,
          max_to_keep=max_to_keep,
          checkpoint_interval=checkpoint_interval,
      )

    # Make sure we export the last checkpoint.
    last_checkpoint = (
        self.global_step.numpy() == self._config.trainer.train_steps)
    checkpoint_path = self._export_ckpt_manager.save(
        checkpoint_number=self.global_step.numpy(),
        check_interval=not last_checkpoint)
    if checkpoint_path:
      logging.info('Checkpoints exported: %s.', checkpoint_path)
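

# Usage sketch (illustrative only; `my_prog_task`, `experiment_config` and the
# directory below are hypothetical placeholders, not symbols defined in this
# module). The task object is expected to implement both base_task.Task and
# policies.ProgressivePolicy:
#
#   trainer = ProgressiveTrainer(
#       config=experiment_config,
#       prog_task=my_prog_task,
#       ckpt_dir='/tmp/model_dir')
#   controller = orbit.Controller(
#       trainer=trainer,
#       evaluator=trainer,
#       global_step=trainer.global_step,
#       steps_per_loop=experiment_config.trainer.steps_per_loop)
#   controller.train(steps=experiment_config.trainer.train_steps)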