# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Progressive Trainer implementation.

The trainer implements the Orbit `StandardTrainer` and
`StandardEvaluator` interfaces. Trainers inside this project should be
interchangeable and independent of model architectures and tasks.
"""

import dataclasses
import os
from typing import Any, Optional

# Import libraries
from absl import logging
import gin
import orbit
import tensorflow as tf, tf_keras
from official.core import base_task
from official.core import base_trainer as trainer_lib
from official.core import config_definitions
from official.modeling.fast_training.progressive import policies
from official.modeling.fast_training.progressive import utils

ExperimentConfig = config_definitions.ExperimentConfig


@dataclasses.dataclass
class ProgressiveTrainerConfig(config_definitions.TrainerConfig):
  """Configuration for progressive trainer.

  Attributes:
    progressive: A task-specific config. Users can subclass ProgressiveConfig
      and define any task-specific settings in their subclass.
    export_checkpoint: A bool. Whether to export checkpoints in a
      non-progressive manner (without the volatiles wrapper) so that
      downstream tasks can load checkpoints from a progressive trainer as if
      they were regular checkpoints.
    export_checkpoint_interval: An int. The number of steps between exporting
      checkpoints. If None (by default), will use the same value as
      TrainerConfig.checkpoint_interval.
    export_max_to_keep: An int. The maximum number of exported checkpoints to
      keep. If None (by default), will use the same value as
      TrainerConfig.max_to_keep.
    export_only_final_stage_ckpt: A bool. Whether to export checkpoints only
      during the final progressive training stage, i.e. whether not to export
      small, partial models. In many cases, it is not meaningful to finetune a
      small, partial model in downstream tasks.
  """
  progressive: Optional[policies.ProgressiveConfig] = None
  export_checkpoint: bool = True
  export_checkpoint_interval: Optional[int] = None
  export_max_to_keep: Optional[int] = None
  export_only_final_stage_ckpt: bool = True
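
# An illustrative sketch of constructing this config (hypothetical values;
# `MyTaskProgressiveConfig` stands in for a task-specific subclass of
# policies.ProgressiveConfig): export plain checkpoints every 1000 steps,
# but only during the final stage.
#
#   prog_trainer_config = ProgressiveTrainerConfig(
#       progressive=MyTaskProgressiveConfig(),
#       export_checkpoint=True,
#       export_checkpoint_interval=1000,
#       export_max_to_keep=5,
#       export_only_final_stage_ckpt=True)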


@gin.configurable
class ProgressiveTrainer(trainer_lib.Trainer):
  """Implements the progressive trainer shared for TensorFlow models."""

  def __init__(
      self,
      config: ExperimentConfig,
      prog_task: base_task.Task,  # also implements ProgressivePolicy.
      ckpt_dir: str = '',
      train: bool = True,
      evaluate: bool = True,
      checkpoint_exporter: Any = None):
    """Initialize common trainer for TensorFlow models.

    Args:
      config: An `ExperimentConfig` instance specifying experiment config.
      prog_task: An instance that implements both policies.ProgressivePolicy
        and base_task.Task.
      ckpt_dir: Checkpoint directory.
      train: bool, whether or not this trainer will be used for training.
        Defaults to True.
      evaluate: bool, whether or not this trainer will be used for evaluation.
        Defaults to True.
      checkpoint_exporter: An object that has the `maybe_export_checkpoint`
        interface.
    """
    # Gets the current distribution strategy. If not inside any strategy scope,
    # it gets a single-replica no-op strategy.
    self._strategy = tf.distribute.get_strategy()
    self._config = config
    self._runtime_options = trainer_lib.get_runtime_options(config)
    self._task = prog_task

    # Directory for non-progressive checkpoint
    self._export_ckpt_dir = os.path.join(ckpt_dir, 'exported_ckpts')
    tf.io.gfile.makedirs(self._export_ckpt_dir)
    self._export_ckpt_manager = None

    # Receive other checkpoint exporters, e.g., the best checkpoint exporter.
    # TODO(lehou): unify the checkpoint exporting logic, although the default
    # setting does not use checkpoint_exporter.
    self._checkpoint_exporter = checkpoint_exporter

    self._global_step = orbit.utils.create_global_step()

    self._checkpoint = utils.CheckpointWithHooks(
        before_load_hook=self._update_pt_stage_from_ckpt,
        global_step=self.global_step,
        **self._task.cur_checkpoint_items)

    self._train_loss = tf_keras.metrics.Mean('training_loss', dtype=tf.float32)
    self._validation_loss = tf_keras.metrics.Mean(
        'validation_loss', dtype=tf.float32)
    self._train_metrics = self.task.build_metrics(
        training=True) + self.model.metrics
    self._validation_metrics = self.task.build_metrics(
        training=False) + self.model.metrics

    if train:
      orbit.StandardTrainer.__init__(
          self,
          None,  # Manage train_dataset by ourselves, not by StandardTrainer.
          options=orbit.StandardTrainerOptions(
              use_tf_while_loop=config.trainer.train_tf_while_loop,
              use_tf_function=config.trainer.train_tf_function))

    if evaluate:
      orbit.StandardEvaluator.__init__(
          self,
          None,  # Manage eval_dataset by ourselves, not by StandardEvaluator.
          options=orbit.StandardEvaluatorOptions(
              use_tf_function=config.trainer.eval_tf_function))

  @property
  def model(self):
    return self._task.cur_model

  @property
  def optimizer(self):
    return self._task.cur_optimizer

  # override
  @property
  def train_dataset(self):
    """Overriding StandardTrainer.train_dataset."""
    return self._task.cur_train_dataset

  # override
  @train_dataset.setter
  def train_dataset(self, _):
    raise SyntaxError('Please do not set train_dataset. Progressive training '
                      'relies on progressive policy to manage train dataset.')

  # override
  @property
  def eval_dataset(self):
    """Overriding StandardEvaluator.eval_dataset."""
    return self._task.cur_eval_dataset

  # override
  @eval_dataset.setter
  def eval_dataset(self, _):
    raise SyntaxError('Please do not set eval_dataset. Progressive training '
                      'relies on progressive policy to manage eval dataset.')

  def train_loop_end(self):
    """See base class."""
    logs = {}
    for metric in self.train_metrics + [self.train_loss]:
      logs[metric.name] = metric.result()
      metric.reset_states()
    if callable(self.optimizer.learning_rate):
      logs['learning_rate'] = self.optimizer.learning_rate(
          self.optimizer.iterations)
    else:
      logs['learning_rate'] = self.optimizer.learning_rate

    self._maybe_export_non_progressive_checkpoint(self._export_ckpt_dir)
    if self._task.is_stage_advancing(self.global_step.numpy()):
      old_train_dataset = self.train_dataset

      # Update progressive properties
      self._task.update_pt_stage(self.global_step.numpy())

      # Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
      # rebuild the train and eval functions with the updated model.
      self._train_loop_fn = None
      self._eval_loop_fn = None

      if self.train_dataset != old_train_dataset:
        # Setting `self._train_iter` to None will rebuild the dataset iterator.
        self._train_iter = None

      # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
      # for exporting.
      self._export_ckpt_manager = None

    return logs

  def _update_pt_stage_from_ckpt(self, ckpt_file):
    """Update stage properties based on the global_step variable in a ckpt file.

    Before loading variables from a checkpoint file, we need to go to the
    correct stage and build the corresponding model and optimizer, to make
    sure that we restore variables of the right model and optimizer.

    Args:
      ckpt_file: Checkpoint file that will be restored/read from.
    """
    if not ckpt_file:
      return
    ckpt = tf.train.Checkpoint(global_step=self.global_step)
    ckpt.read(ckpt_file).expect_partial().assert_existing_objects_matched()

    if self._task.is_stage_advancing(self.global_step.numpy()):
      old_train_dataset = self.train_dataset

      # Update progressive properties
      self._task.update_pt_stage(self.global_step.numpy(), pass_old_model=False)

      # Setting `self._train_loop_fn` and `self._eval_loop_fn` to None will
      # rebuild the train and eval functions with the updated model.
      self._train_loop_fn = None
      self._eval_loop_fn = None

      if self.train_dataset != old_train_dataset:
        # Setting `self._train_iter` to None will rebuild the dataset iterator.
        self._train_iter = None

      # Setting `self._export_ckpt_manager` to None will rebuild the checkpoint
      # for exporting.
      self._export_ckpt_manager = None

  def _maybe_export_non_progressive_checkpoint(self, export_ckpt_dir):
    """Export checkpoints in non-progressive format.

    This basically removes the wrapping of self._task.cur_checkpoint_items
    -- just saves the model, optimizer, etc., directly.
    The purpose is to let your downstream tasks use these checkpoints.

    Args:
      export_ckpt_dir: A str. Folder of exported checkpoints.
    """
    if not self.config.trainer.export_checkpoint:
      logging.info('Not exporting checkpoints.')
      return
    if not self._task.is_last_stage and (
        self.config.trainer.export_only_final_stage_ckpt):
      logging.info('Not exporting checkpoints until the last stage.')
      return

    if self._export_ckpt_manager is None:
      # Create the checkpoint object only now, to make sure we use
      # progressive_policy.cur_model and progressive_policy.cur_optimizer of
      # the current stage.
      if hasattr(self.model, 'checkpoint_items'):
        checkpoint_items = self.model.checkpoint_items
      else:
        checkpoint_items = {}
      checkpoint = tf.train.Checkpoint(
          global_step=self.global_step,
          model=self.model,
          optimizer=self.optimizer,
          **checkpoint_items)

      max_to_keep = self.config.trainer.export_max_to_keep or (
          self.config.trainer.max_to_keep)
      checkpoint_interval = self.config.trainer.export_checkpoint_interval or (
          self.config.trainer.checkpoint_interval)
      self._export_ckpt_manager = tf.train.CheckpointManager(
          checkpoint,
          directory=export_ckpt_dir,
          checkpoint_name='ckpt',
          step_counter=self.global_step,
          max_to_keep=max_to_keep,
          checkpoint_interval=checkpoint_interval,
      )

    # Make sure we export the last checkpoint.
    last_checkpoint = (
        self.global_step.numpy() == self._config.trainer.train_steps)
    checkpoint_path = self._export_ckpt_manager.save(
        checkpoint_number=self.global_step.numpy(),
        check_interval=not last_checkpoint)
    if checkpoint_path:
      logging.info('Checkpoints exported: %s.', checkpoint_path)
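
# An illustrative sketch (not part of this module) of how a downstream task
# could load an exported, non-progressive checkpoint: the manager above saves
# plain `global_step`/`model`/`optimizer` entries, so the model weights can be
# restored directly. `downstream_model` and `export_ckpt_dir` are hypothetical
# placeholders.
#
#   latest = tf.train.latest_checkpoint(export_ckpt_dir)
#   if latest:
#     tf.train.Checkpoint(model=downstream_model).restore(
#         latest).expect_partial()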