# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for the progressive trainer."""
# pylint: disable=g-direct-tensorflow-import
import os

from absl.testing import parameterized
import orbit
import tensorflow as tf, tf_keras
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import config_definitions as cfg
from official.modeling import optimization
from official.modeling.fast_training.progressive import policies
from official.modeling.fast_training.progressive import trainer as trainer_lib
from official.nlp.configs import bert
from official.utils.testing import mock_task


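# Distribution strategies that each parameterized test runs under.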
def all_strategy_combinations():
return combinations.combine(
distribution=[
strategy_combinations.default_strategy,
strategy_combinations.cloud_tpu_strategy,
strategy_combinations.one_device_strategy_gpu,
      ],)


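# Minimal experiment config; checkpoints are exported after every stage.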
def get_exp_config():
return cfg.ExperimentConfig(
task=cfg.TaskConfig(
model=bert.PretrainerConfig()),
trainer=trainer_lib.ProgressiveTrainerConfig(
export_checkpoint=True,
export_checkpoint_interval=1,
          export_only_final_stage_ckpt=False))


class TestPolicy(policies.ProgressivePolicy, mock_task.MockTask):
"""Just for testing purposes."""
def __init__(self, strategy, task_config, change_train_dataset=True):
self._strategy = strategy
self._change_train_dataset = change_train_dataset
self._my_train_dataset = None
mock_task.MockTask.__init__(self, params=task_config, logging_dir=None)
    policies.ProgressivePolicy.__init__(self)

def num_stages(self) -> int:
    return 2

def num_steps(self, stage_id: int) -> int:
    return 2 if stage_id == 0 else 4

def get_model(self,
stage_id: int,
old_model: tf_keras.Model) -> tf_keras.Model:
del stage_id, old_model
    return self.build_model()

def get_optimizer(self, stage_id: int) -> tf_keras.optimizers.Optimizer:
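    # Stage 0 trains with SGD; later stages switch to AdamW.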
optimizer_type = 'sgd' if stage_id == 0 else 'adamw'
optimizer_config = cfg.OptimizationConfig({
'optimizer': {'type': optimizer_type},
'learning_rate': {'type': 'constant'}})
opt_factory = optimization.OptimizerFactory(optimizer_config)
    return opt_factory.build_optimizer(opt_factory.build_learning_rate())

def get_train_dataset(self, stage_id: int) -> tf.data.Dataset:
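    # Reuse the cached dataset when the policy is not switching datasets
    # between stages.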
if not self._change_train_dataset and self._my_train_dataset:
return self._my_train_dataset
if self._strategy:
self._my_train_dataset = orbit.utils.make_distributed_dataset(
self._strategy,
self._build_inputs,
stage_id)
else:
self._my_train_dataset = self._build_inputs(stage_id)
    return self._my_train_dataset

def get_eval_dataset(self, stage_id: int) -> tf.data.Dataset:
if self._strategy:
return orbit.utils.make_distributed_dataset(
self._strategy,
self._build_inputs,
stage_id)
    return self._build_inputs(stage_id)

def _build_inputs(self, stage_id):
def dummy_data(_):
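      # The per-stage batch size lets the tests tell which dataset is active.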
batch_size = 2 if stage_id == 0 else 1
x = tf.zeros(shape=(batch_size, 2), dtype=tf.float32)
label = tf.zeros(shape=(batch_size, 1), dtype=tf.float32)
      return x, label

dataset = tf.data.Dataset.range(1)
dataset = dataset.repeat()
return dataset.map(
        dummy_data, num_parallel_calls=tf.data.experimental.AUTOTUNE)


class TrainerTest(tf.test.TestCase, parameterized.TestCase):
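  """Tests stage checkpointing and dataset switching."""
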
def setUp(self):
super(TrainerTest, self).setUp()
    self._config = get_exp_config()

def create_test_trainer(self, distribution, model_dir, change_train_dataset):
trainer = trainer_lib.ProgressiveTrainer(
self._config,
prog_task=TestPolicy(
distribution, self._config.task, change_train_dataset),
ckpt_dir=model_dir)
    return trainer

@combinations.generate(all_strategy_combinations())
def test_checkpointing(self, distribution):
model_dir = self.get_temp_dir()
ckpt_file = os.path.join(model_dir, 'ckpt')
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, True)
self.assertFalse(trainer._task.is_last_stage)
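      # Train through stage 0 (2 steps) into the final stage.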
trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
self.assertTrue(trainer._task.is_last_stage)
trainer.checkpoint.save(ckpt_file)
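      # A fresh trainer starts at stage 0; restoring the checkpoint should
      # recover the final-stage state.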
trainer = self.create_test_trainer(distribution, model_dir, True)
self.assertFalse(trainer._task.is_last_stage)
trainer.checkpoint.restore(ckpt_file + '-1')
      self.assertTrue(trainer._task.is_last_stage)

@combinations.generate(all_strategy_combinations())
def test_train_dataset(self, distribution):
model_dir = self.get_temp_dir()
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, True)
      # Before training, the stage-0 dataset (batch size 2) is in use.
train_iter = tf.nest.map_structure(iter, trainer.train_dataset)
train_data = train_iter.next()[0]
if distribution.num_replicas_in_sync > 1:
train_data = train_data.values[0]
self.assertEqual(train_data.shape[0], 2)
trainer.train(tf.convert_to_tensor(4, dtype=tf.int32))
      # After 4 train steps, the stage-1 dataset (batch size 1) is in use.
train_iter = tf.nest.map_structure(iter, trainer.train_dataset)
train_data = train_iter.next()[0]
if distribution.num_replicas_in_sync > 1:
train_data = train_data.values[0]
self.assertEqual(train_data.shape[0], 1)
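      # The trainer rejects direct assignment to train_dataset.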
with self.assertRaises(SyntaxError):
        trainer.train_dataset = None

@combinations.generate(all_strategy_combinations())
def test_train_dataset_no_switch(self, distribution):
model_dir = self.get_temp_dir()
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, False)
trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
# _train_iter is not reset since the dataset is not changed.
self.assertIsNotNone(trainer._train_iter)
with distribution.scope():
trainer = self.create_test_trainer(distribution, model_dir, True)
trainer.train(tf.convert_to_tensor(2, dtype=tf.int32))
# _train_iter is reset since the dataset changed.
      self.assertIsNone(trainer._train_iter)


class TrainerWithMaskedLMTaskTest(tf.test.TestCase, parameterized.TestCase):
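  """Tests the train/eval loops and optimizer configuration."""
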
def setUp(self):
super(TrainerWithMaskedLMTaskTest, self).setUp()
    self._config = get_exp_config()

def create_test_trainer(self, distribution):
trainer = trainer_lib.ProgressiveTrainer(
self._config,
prog_task=TestPolicy(distribution, self._config.task),
ckpt_dir=self.get_temp_dir())
    return trainer

@combinations.generate(all_strategy_combinations())
def test_trainer_train(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer(distribution)
logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('training_loss', logs)
      self.assertIn('learning_rate', logs)

@combinations.generate(all_strategy_combinations())
def test_trainer_validate(self, distribution):
with distribution.scope():
trainer = self.create_test_trainer(distribution)
logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
self.assertIn('validation_loss', logs)
      self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)

@combinations.generate(
combinations.combine(
mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
loss_scale=[None, 'dynamic', 128, 256],
))
def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
config = cfg.ExperimentConfig(
task=cfg.TaskConfig(
model=bert.PretrainerConfig()),
runtime=cfg.RuntimeConfig(
mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
trainer=trainer_lib.ProgressiveTrainerConfig(
export_checkpoint=True,
export_checkpoint_interval=1,
export_only_final_stage_ckpt=False))
task = TestPolicy(None, config.task)
trainer = trainer_lib.ProgressiveTrainer(config, task, self.get_temp_dir())
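    # Unless float16 with an explicit or dynamic loss scale is configured,
    # the optimizer should be the plain stage-0 SGD optimizer, not a
    # loss-scale wrapper.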
if mixed_precision_dtype != 'float16':
self.assertIsInstance(
trainer.optimizer,
(tf_keras.optimizers.SGD, tf_keras.optimizers.legacy.SGD))
elif mixed_precision_dtype == 'float16' and loss_scale is None:
self.assertIsInstance(
trainer.optimizer,
(tf_keras.optimizers.SGD, tf_keras.optimizers.legacy.SGD))
metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', metrics)


if __name__ == '__main__':
tf.test.main()