# Source: ASL-MoViNet-T5-translator/official/core/base_trainer_test.py
# (Hugging Face file-viewer header: commit 93528c6 "updates" by deanna-emery,
# 13 kB; "raw" and "history blame" were page links.)
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Tests for official.core.base_trainer."""
# pylint: disable=g-direct-tensorflow-import
import gc
import multiprocessing
import os
import sys
from absl.testing import parameterized
import orbit
import portpicker
import tensorflow as tf, tf_keras
from tensorflow.python.distribute import combinations
from tensorflow.python.distribute import strategy_combinations
from official.core import base_trainer as trainer_lib
from official.core import config_definitions as cfg
from official.core import train_lib
from official.utils.testing import mock_task
# True when the path of the running test binary (sys.argv[0]) contains
# 'test_tpu'. Checked by the async-trainer tests to decide whether to skip.
TPU_TEST = 'test_tpu' in sys.argv[0]
# True when the path of the running test binary (sys.argv[0]) contains
# 'test_gpu'. Checked by the async-trainer tests to decide whether to skip.
GPU_TEST = 'test_gpu' in sys.argv[0]
def all_strategy_combinations():
  """Returns the `distribution` test combinations used by the trainer tests.

  The `distribution` parameter ranges over the default strategy, the cloud TPU
  strategy and the one-device GPU strategy.
  """
  strategies = [
      strategy_combinations.default_strategy,
      strategy_combinations.cloud_tpu_strategy,
      strategy_combinations.one_device_strategy_gpu,
  ]
  return combinations.combine(distribution=strategies)
def create_in_process_cluster(num_workers, num_ps):
  """Creates and starts local gRPC servers and returns a cluster resolver.

  Args:
    num_workers: Number of tasks in the 'worker' job.
    num_ps: Number of tasks in the 'ps' job. The 'ps' job is left out of the
      cluster spec when this is not positive.

  Returns:
    A `SimpleClusterResolver` over the cluster spec, using `rpc_layer='grpc'`.
  """
  # Ports are reserved for all workers first, then for parameter servers.
  worker_addresses = [
      'localhost:%s' % portpicker.pick_unused_port()
      for _ in range(num_workers)
  ]
  ps_addresses = [
      'localhost:%s' % portpicker.pick_unused_port() for _ in range(num_ps)
  ]

  jobs = {'worker': worker_addresses}
  if num_ps > 0:
    jobs['ps'] = ps_addresses
  cluster_spec = tf.train.ClusterSpec(jobs)

  # Workers need some inter_ops threads to work properly.
  worker_config = tf.compat.v1.ConfigProto()
  required_threads = num_workers + 1
  if multiprocessing.cpu_count() < required_threads:
    worker_config.inter_op_parallelism_threads = required_threads

  for worker_index in range(num_workers):
    tf.distribute.Server(
        cluster_spec,
        job_name='worker',
        task_index=worker_index,
        config=worker_config,
        protocol='grpc')

  for ps_index in range(num_ps):
    tf.distribute.Server(
        cluster_spec, job_name='ps', task_index=ps_index, protocol='grpc')

  return tf.distribute.cluster_resolver.SimpleClusterResolver(
      cluster_spec, rpc_layer='grpc')
def dataset_fn(input_context=None):
  """Builds a repeating dataset whose elements are float32 zeros of shape (1, 1).

  Args:
    input_context: Unused; discarded immediately.

  Returns:
    A `tf.data.Dataset` that repeats indefinitely and maps every element to
    `tf.zeros((1, 1), dtype=tf.float32)`.
  """
  del input_context

  def make_zero_example(_):
    return tf.zeros((1, 1), dtype=tf.float32)

  return tf.data.Dataset.range(1).repeat().map(
      make_zero_example, num_parallel_calls=tf.data.experimental.AUTOTUNE)
class MockAsyncTrainer(trainer_lib._AsyncTrainer):
  """Minimal `_AsyncTrainer` subclass whose steps only increment counters."""

  def __init__(self):
    """Initializes async state, the step counters and both datasets."""
    self._strategy = tf.distribute.get_strategy()
    self.init_async()

    def new_step_counter(name):
      # Non-trainable int64 counter starting at zero.
      return tf.Variable(
          0,
          dtype=tf.int64,
          name=name,
          trainable=False,
          aggregation=tf.VariableAggregation.ONLY_FIRST_REPLICA)

    self.global_step = new_step_counter('global_step')
    self.eval_global_step = new_step_counter('eval_global_step')

    orbit.StandardTrainer.__init__(
        self,
        self.distribute_dataset(dataset_fn),
        options=orbit.StandardTrainerOptions())

    orbit.StandardEvaluator.__init__(
        self,
        self.distribute_dataset(dataset_fn),
        options=orbit.StandardEvaluatorOptions(use_tf_while_loop=True))

  def train_loop_begin(self):
    """Resets the training step counter to zero."""
    self.global_step.assign(0)

  def train_step(self, iterator):
    """Consumes one element and adds one to `global_step` via the strategy."""

    def count_train_step(_):
      self.global_step.assign_add(1)

    batch = next(iterator)
    self._strategy.run(count_train_step, args=(batch,))

  def train_loop_end(self):
    """Joins pending async work and returns the training step count."""
    self.join()
    return self.global_step.numpy()

  def eval_begin(self):
    """Resets the evaluation step counter to zero."""
    self.eval_global_step.assign(0)

  def eval_step(self, iterator):
    """Consumes one element and adds one to `eval_global_step`."""

    def count_eval_step(_):
      self.eval_global_step.assign_add(1)

    batch = next(iterator)
    self._strategy.run(count_eval_step, args=(batch,))

  def eval_end(self):
    """Joins pending async work and returns the evaluation step count."""
    self.join()
    return self.eval_global_step.numpy()
class TrainerTest(tf.test.TestCase, parameterized.TestCase):
  """Tests for `Trainer` and `_AsyncTrainer` in `official.core.base_trainer`."""

  @staticmethod
  def _sgd_optimization_config():
    """Returns an SGD optimizer config with a constant learning rate."""
    return cfg.OptimizationConfig({
        'optimizer': {
            'type': 'sgd'
        },
        'learning_rate': {
            'type': 'constant'
        },
    })

  def _skip_if_accelerator_test(self):
    """Skips the current test when running as a TPU or GPU test binary."""
    if TPU_TEST or GPU_TEST:
      self.skipTest('Async training is not available on GPU/TPU.')

  def _create_parameter_server_strategy(self, num_workers=3, num_ps=2):
    """Starts an in-process cluster and returns a ParameterServerStrategy.

    Args:
      num_workers: Number of worker tasks to start.
      num_ps: Number of parameter-server tasks to start.

    Returns:
      A `tf.distribute.experimental.ParameterServerStrategy` built from the
      resolver of the newly started cluster.
    """
    cluster_resolver = create_in_process_cluster(num_workers, num_ps)
    return tf.distribute.experimental.ParameterServerStrategy(cluster_resolver)

  def setUp(self):
    """Builds the default experiment config used by most tests."""
    super().setUp()
    self._config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            optimizer_config=self._sgd_optimization_config()))

  def tearDown(self):
    """Fails the test if it left uncollectable garbage behind."""
    gc.collect()
    # This will only contain uncollectable garbage, i.e. reference cycles
    # involving objects with __del__ defined.
    self.assertEmpty(gc.garbage)
    super().tearDown()

  def create_test_trainer(self, config, model_dir=None, task=None):
    """Builds a `Trainer` for `config`.

    Args:
      config: Experiment config; its `task`, `trainer.optimizer_config` and
        `runtime` fields are read here.
      model_dir: Optional directory passed to the mock task as `logging_dir`
        and to the best-checkpoint exporter factory.
      task: Optional task to use; a `MockTask` is created when omitted.

    Returns:
      The constructed `trainer_lib.Trainer`.
    """
    task = task or mock_task.MockTask(config.task, logging_dir=model_dir)
    ckpt_exporter = train_lib.maybe_create_best_ckpt_exporter(config, model_dir)
    trainer = trainer_lib.Trainer(
        config,
        task,
        model=task.build_model(),
        optimizer=task.create_optimizer(config.trainer.optimizer_config,
                                        config.runtime),
        checkpoint_exporter=ckpt_exporter)
    return trainer

  @combinations.generate(all_strategy_combinations())
  def test_trainer_train(self, distribution):
    """Training logs contain the loss and the learning rate."""
    with distribution.scope():
      trainer = self.create_test_trainer(self._config)
      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('training_loss', logs)
      self.assertIn('learning_rate', logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_passing_datasets(self, distribution):
    """Trainer works with datasets passed in instead of configured ones."""
    with distribution.scope():
      task = mock_task.MockTask(self._config)
      train_dataset = orbit.utils.make_distributed_dataset(
          distribution, task.build_inputs, self._config.task.train_data)
      validation_dataset = orbit.utils.make_distributed_dataset(
          distribution, task.build_inputs, self._config.task.validation_data)
      # Clear the configured data so only the explicit datasets are available.
      self._config.task.train_data = None
      self._config.task.validation_data = None
      trainer = trainer_lib.Trainer(
          self._config,
          task,
          model=task.build_model(),
          optimizer=task.create_optimizer(self._config.trainer.optimizer_config,
                                          self._config.runtime),
          train_dataset=train_dataset,
          validation_dataset=validation_dataset)
    logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', logs)
    self.assertIn('learning_rate', logs)
    logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('validation_loss', logs)

  def test_base_async_trainer(self):
    """MockAsyncTrainer runs the requested number of train and eval steps."""
    self._skip_if_accelerator_test()
    distribution = self._create_parameter_server_strategy()
    with distribution.scope():
      trainer = MockAsyncTrainer()
      # NOTE(review): MockAsyncTrainer.__init__ already calls init_async();
      # the second call is kept from the original test.
      trainer.init_async()
      self.assertIsInstance(
          trainer._coordinator,
          tf.distribute.experimental.coordinator.ClusterCoordinator)
      self.assertEqual(trainer.train(tf.constant(10)), 10)
      self.assertEqual(trainer.evaluate(tf.constant(11)), 11)

  def test_async_trainer_train(self):
    """Training under ParameterServerStrategy reports loss and learning rate."""
    self._skip_if_accelerator_test()
    distribution = self._create_parameter_server_strategy()
    with distribution.scope():
      config = cfg.ExperimentConfig(**self._config.as_dict())
      config.trainer.eval_tf_while_loop = True
      trainer = self.create_test_trainer(config)
      logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('training_loss', logs)
      self.assertIn('learning_rate', logs)

  def test_async_trainer_validate(self):
    """Evaluation under ParameterServerStrategy reports 'acc' and the loss."""
    self._skip_if_accelerator_test()
    distribution = self._create_parameter_server_strategy()
    with distribution.scope():
      config = cfg.ExperimentConfig(**self._config.as_dict())
      config.trainer.eval_tf_while_loop = True
      trainer = self.create_test_trainer(config)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertIn('acc', logs)
      self.assertIn('validation_loss', logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_validate(self, distribution):
    """Evaluation logs the per-replica step counter and the validation loss."""
    with distribution.scope():
      trainer = self.create_test_trainer(self._config)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
      self.assertIn('validation_loss', logs)

  @combinations.generate(all_strategy_combinations())
  def test_trainer_validate_without_loss(self, distribution):
    """No 'validation_loss' is logged when the task omits its loss entry."""

    class MockTaskWithoutValidationLoss(mock_task.MockTask):

      def validation_step(self, inputs, model, metrics=None):
        # Disable validation loss.
        logs = super().validation_step(inputs, model)
        del logs[self.loss]
        return logs

    with distribution.scope():
      task = MockTaskWithoutValidationLoss()
      trainer = self.create_test_trainer(self._config, task=task)
      logs = trainer.evaluate(tf.convert_to_tensor(5, dtype=tf.int32))
      self.assertEqual(logs['counter'], 5. * distribution.num_replicas_in_sync)
      self.assertNotIn('validation_loss', logs)

  @combinations.generate(
      combinations.combine(
          mixed_precision_dtype=['float32', 'bfloat16', 'float16'],
          loss_scale=[None, 'dynamic', 128, 256],
      ))
  def test_configure_optimizer(self, mixed_precision_dtype, loss_scale):
    """The optimizer is loss-scale wrapped only for float16 mixed precision."""
    config = cfg.ExperimentConfig(
        runtime=cfg.RuntimeConfig(
            mixed_precision_dtype=mixed_precision_dtype, loss_scale=loss_scale),
        trainer=cfg.TrainerConfig(
            optimizer_config=self._sgd_optimization_config()))
    trainer = self.create_test_trainer(config)
    if mixed_precision_dtype == 'float16':
      self.assertIsInstance(trainer.optimizer,
                            tf_keras.mixed_precision.LossScaleOptimizer)
      if loss_scale in (None, 'dynamic'):
        self.assertTrue(trainer.optimizer.dynamic)
      else:
        self.assertFalse(trainer.optimizer.dynamic)
        self.assertEqual(trainer.optimizer.initial_scale, loss_scale)
    else:
      self.assertIsInstance(
          trainer.optimizer,
          (tf_keras.optimizers.SGD, tf_keras.optimizers.legacy.SGD))

    metrics = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', metrics)

  def test_export_best_ckpt(self):
    """Train + evaluate writes `best_ckpt/info.json` under the model dir."""
    config = cfg.ExperimentConfig(
        trainer=cfg.TrainerConfig(
            best_checkpoint_export_subdir='best_ckpt',
            best_checkpoint_eval_metric='acc',
            optimizer_config=self._sgd_optimization_config()))
    model_dir = self.get_temp_dir()
    trainer = self.create_test_trainer(config, model_dir=model_dir)
    trainer.train(tf.convert_to_tensor(1, dtype=tf.int32))
    trainer.evaluate(tf.convert_to_tensor(1, dtype=tf.int32))
    self.assertTrue(
        tf.io.gfile.exists(os.path.join(model_dir, 'best_ckpt', 'info.json')))

  def test_model_with_compiled_loss(self):
    """Training still logs 'training_loss' when the model has a compiled loss."""
    task = mock_task.MockTask()
    model = task.build_model()
    model.compile(loss=tf_keras.losses.CategoricalCrossentropy())
    trainer = trainer_lib.Trainer(
        self._config,
        task,
        model=model,
        optimizer=task.create_optimizer(self._config.trainer.optimizer_config))
    logs = trainer.train(tf.convert_to_tensor(5, dtype=tf.int32))
    self.assertIn('training_loss', logs)
# Run the tests through TensorFlow's test runner when executed as a script.
if __name__ == '__main__':
  tf.test.main()