deanna-emery's picture
updates
93528c6
raw
history blame
6.64 kB
# Copyright 2023 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Test Transformer model."""
import os
import re
import sys
import unittest
from absl import flags
from absl.testing import flagsaver
import tensorflow as tf, tf_keras
from tensorflow.python.eager import context # pylint: disable=ungrouped-imports
from official.legacy.transformer import misc
from official.legacy.transformer import transformer_main
# Sub-directory name placed under each test's temp dir to form FLAGS.model_dir.
FIXED_TIMESTAMP = 'my_time_stamp'
# Matches names of the form `weights-epoch-<anything>.hdf5`.
# NOTE(review): not referenced by any test visible in this module.
WEIGHT_PATTERN = re.compile(r'weights-epoch-.+\.hdf5')
# Shorthand for the process-wide absl flag registry mutated by every test.
FLAGS = flags.FLAGS
def _generate_file(filepath, lines):
with open(filepath, 'w') as f:
for l in lines:
f.write('{}\n'.format(l))
class TransformerTaskTest(tf.test.TestCase):
  """Smoke tests for `transformer_main.TransformerTask`.

  Each test configures the global absl FLAGS for a tiny synthetic-data run and
  then calls `train()`, `predict()` or `eval()` on a fresh task.
  """

  # Snapshot of the flag values taken right after the transformer flags are
  # first defined and parsed. It is restored at the start of every later test
  # so that flag mutations made by one test do not leak into the next.
  local_flags = None

  def setUp(self):  # pylint: disable=g-missing-super-call
    """Restores pristine flags, applies tiny-model defaults, saves policies."""
    temp_dir = self.get_temp_dir()
    if TransformerTaskTest.local_flags is None:
      # Flags may only be defined once per process, so define and parse them
      # on the first setUp and snapshot the result.
      misc.define_transformer_flags()
      # Loads flags, array cannot be blank.
      flags.FLAGS(['foo'])
      TransformerTaskTest.local_flags = flagsaver.save_flag_values()
    else:
      flagsaver.restore_flag_values(TransformerTaskTest.local_flags)
    FLAGS.model_dir = os.path.join(temp_dir, FIXED_TIMESTAMP)
    FLAGS.param_set = 'tiny'
    FLAGS.use_synthetic_data = True
    FLAGS.steps_between_evals = 1
    FLAGS.train_steps = 1
    FLAGS.validation_steps = 1
    FLAGS.batch_size = 4
    FLAGS.max_length = 1
    FLAGS.num_gpus = 1
    FLAGS.distribution_strategy = 'off'
    FLAGS.dtype = 'fp32'
    self.model_dir = FLAGS.model_dir
    self.temp_dir = temp_dir
    self.vocab_file = os.path.join(temp_dir, 'vocab')
    self.vocab_size = misc.get_model_params(FLAGS.param_set, 0)['vocab_size']
    self.bleu_source = os.path.join(temp_dir, 'bleu_source')
    self.bleu_ref = os.path.join(temp_dir, 'bleu_ref')
    # Snapshot the global mixed-precision policies so fp16 tests cannot leak
    # their policy into later tests. NOTE(review): transformer_main may set the
    # policy through `tf_keras` (this module imports it) rather than through
    # `tf.keras`, and the two may hold separate global state, so both are
    # saved here. TODO confirm against transformer_main.
    self.orig_policy = (
        tf.compat.v2.keras.mixed_precision.global_policy())
    self.orig_tf_keras_policy = tf_keras.mixed_precision.global_policy()

  def tearDown(self):  # pylint: disable=g-missing-super-call
    """Restores the mixed-precision policies captured in setUp."""
    tf.compat.v2.keras.mixed_precision.set_global_policy(self.orig_policy)
    tf_keras.mixed_precision.set_global_policy(self.orig_tf_keras_policy)

  def _assert_exists(self, filepath):
    """Asserts that `filepath` exists on disk."""
    self.assertTrue(os.path.exists(filepath))

  def _skip_if_multi_gpu(self):
    """Skips tests that run without a distribution strategy on 2+ GPUs."""
    if context.num_gpus() >= 2:
      self.skipTest('No need to test 2+ GPUs without a distribution strategy.')

  def _skip_unless_gpus(self, num_gpus=2):
    """Skips the current test unless at least `num_gpus` GPUs are visible."""
    available = context.num_gpus()
    if available < num_gpus:
      self.skipTest(
          '{} GPUs are not available for this test. {} GPUs are available'
          .format(num_gpus, available))

  def test_train_no_dist_strat(self):
    self._skip_if_multi_gpu()
    t = transformer_main.TransformerTask(FLAGS)
    t.train()

  def test_train_save_full_model(self):
    self._skip_if_multi_gpu()
    FLAGS.save_weights_only = False
    t = transformer_main.TransformerTask(FLAGS)
    t.train()

  def test_train_static_batch(self):
    self._skip_if_multi_gpu()
    FLAGS.distribution_strategy = 'one_device'
    # Run on the single GPU when CUDA is available, otherwise on CPU.
    FLAGS.num_gpus = 1 if tf.test.is_built_with_cuda() else 0
    FLAGS.static_batch = True
    t = transformer_main.TransformerTask(FLAGS)
    t.train()

  @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
  def test_train_1_gpu_with_dist_strat(self):
    FLAGS.distribution_strategy = 'one_device'
    t = transformer_main.TransformerTask(FLAGS)
    t.train()

  @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
  def test_train_fp16(self):
    FLAGS.distribution_strategy = 'one_device'
    FLAGS.dtype = 'fp16'
    t = transformer_main.TransformerTask(FLAGS)
    t.train()

  @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
  def test_train_2_gpu(self):
    self._skip_unless_gpus(2)
    FLAGS.distribution_strategy = 'mirrored'
    FLAGS.num_gpus = 2
    FLAGS.param_set = 'base'
    t = transformer_main.TransformerTask(FLAGS)
    t.train()

  @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
  def test_train_2_gpu_fp16(self):
    self._skip_unless_gpus(2)
    FLAGS.distribution_strategy = 'mirrored'
    FLAGS.num_gpus = 2
    FLAGS.param_set = 'base'
    FLAGS.dtype = 'fp16'
    t = transformer_main.TransformerTask(FLAGS)
    t.train()

  def _prepare_files_and_flags(self, *extra_flags):
    """Writes fake vocab/BLEU fixtures and parses the flags that point at them.

    Args:
      *extra_flags: Additional command-line style flag strings (for example
        '--dtype=fp16') that are appended before parsing.
    """
    # Make log dir.
    os.makedirs(self.temp_dir, exist_ok=True)
    # Fake vocab, bleu_source and bleu_ref.
    tokens = [
        "'<pad>'", "'<EOS>'", "'_'", "'a'", "'b'", "'c'", "'d'", "'a_'", "'b_'",
        "'c_'", "'d_'"
    ]
    # Pad the vocabulary with numeric tokens up to the model's vocab size.
    tokens += ["'{}'".format(i) for i in range(self.vocab_size - len(tokens))]
    _generate_file(self.vocab_file, tokens)
    _generate_file(self.bleu_source, ['a b', 'c d'])
    _generate_file(self.bleu_ref, ['a b', 'd c'])
    # Update flags. absl expects argv[0] to be the program name.
    update_flags = [
        'ignored_program_name',
        '--vocab_file={}'.format(self.vocab_file),
        '--bleu_source={}'.format(self.bleu_source),
        '--bleu_ref={}'.format(self.bleu_ref),
    ]
    update_flags.extend(extra_flags)
    FLAGS(update_flags)

  def test_predict(self):
    self._skip_if_multi_gpu()
    self._prepare_files_and_flags()
    t = transformer_main.TransformerTask(FLAGS)
    t.predict()

  @unittest.skipUnless(tf.test.is_built_with_cuda(), 'requires GPU')
  def test_predict_fp16(self):
    self._skip_if_multi_gpu()
    self._prepare_files_and_flags('--dtype=fp16')
    t = transformer_main.TransformerTask(FLAGS)
    t.predict()

  def test_eval(self):
    self._skip_if_multi_gpu()
    if 'test_xla' in sys.argv[0]:
      self.skipTest('TODO(xla): Make this test faster under XLA.')
    self._prepare_files_and_flags()
    t = transformer_main.TransformerTask(FLAGS)
    t.eval()
if __name__ == '__main__':
  # Executed only when this module is run directly: hand control to
  # TensorFlow's test entry point.
  tf.test.main()